In [1]:
pip install torch torchvision matplotlib pillow




In [6]:
import os
import random
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms

# Siamese Network Definition
class SiameseNetwork(nn.Module):
    def __init__(self):
        super(SiameseNetwork, self).__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.fc = nn.Sequential(
            nn.Linear(128 * 32 * 32, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.cnn(x)
        x = x.view(x.size(0), -1)  # Flatten
        x = self.fc(x)
        return x

# Contrastive Loss Function
class ContrastiveLoss(nn.Module):
    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        euclidean_distance = nn.functional.pairwise_distance(output1, output2)
        loss = torch.mean(
            (1 - label) * torch.pow(euclidean_distance, 2) +
            (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2)
        )
        return loss

# Dataset for Image Pairs
class ImagePairDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = os.listdir(root_dir)
        self.image_paths = {cls: [os.path.join(root_dir, cls, img) for img in os.listdir(os.path.join(root_dir, cls))]
                            for cls in self.classes}

    def __getitem__(self, index):
        class1 = random.choice(self.classes)
        img1_path = random.choice(self.image_paths[class1])
        img1 = Image.open(img1_path).convert("RGB")

        should_get_same_class = random.randint(0, 1)
        if should_get_same_class:
            img2_path = random.choice(self.image_paths[class1])
            label = 1
        else:
            class2 = random.choice([cls for cls in self.classes if cls != class1])
            img2_path = random.choice(self.image_paths[class2])
            label = 0

        img2 = Image.open(img2_path).convert("RGB")

        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)

        return img1, img2, torch.tensor(label, dtype=torch.float32)

    def __len__(self):
        return sum(len(imgs) for imgs in self.image_paths.values())

# Training Function
def train_model(dataset_path, epochs=5, batch_size=32, learning_rate=0.001):
    transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    dataset = ImagePairDataset(root_dir=dataset_path, transform=transform)
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = SiameseNetwork().to(device)
    criterion = ContrastiveLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0
        for img1, img2, label in data_loader:
            img1, img2, label = img1.to(device), img2.to(device), label.to(device)
            optimizer.zero_grad()
            output1 = model(img1)
            output2 = model(img2)
            loss = criterion(output1, output2, label)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss/len(data_loader):.4f}")

    return model

# Testing Function with Class Name Display
def test_model(model, img1_path, img2_path, class_names):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    model = model.to(device)
    model.eval()

    # Load images and their class names
    img1 = Image.open(img1_path).convert("RGB")
    img2 = Image.open(img2_path).convert("RGB")

    img1_class = img1_path.split("/")[-2]
    img2_class = img2_path.split("/")[-2]

    img1 = transform(img1).unsqueeze(0).to(device)
    img2 = transform(img2).unsqueeze(0).to(device)

    with torch.no_grad():
        output1 = model(img1)
        output2 = model(img2)
        euclidean_distance = nn.functional.pairwise_distance(output1, output2)

    print(f"Class of Image 1: {img1_class}")
    print(f"Class of Image 2: {img2_class}")
    print(f"Euclidean Distance: {euclidean_distance.item():.4f}")

    threshold = 0.5  # Set the threshold based on testing (experiment with this value)
    if euclidean_distance.item() < threshold:
        print("The images are likely from the same class.")
    else:
        print("The images are likely from different classes.")

# Usage Example
# Training
trained_model = train_model(
    dataset_path="/content/drive/MyDrive/few shot learning",
    epochs=20,
    batch_size=32
)

# Testing
test_model(
    model=trained_model,
    img1_path="/content/drive/MyDrive/few shot learning/cat/cat 10.jpeg",
    img2_path="/content/drive/MyDrive/few shot learning/dog/img 27.jpeg",
    class_names=["cat", "dog"]
)


Epoch [1/20], Loss: 122.6227
Epoch [2/20], Loss: 11.4142
Epoch [3/20], Loss: 3.0808
Epoch [4/20], Loss: 1.3770
Epoch [5/20], Loss: 1.1474
Epoch [6/20], Loss: 1.6036
Epoch [7/20], Loss: 1.4107
Epoch [8/20], Loss: 1.5781
Epoch [9/20], Loss: 1.5887
Epoch [10/20], Loss: 1.1281
Epoch [11/20], Loss: 1.2526
Epoch [12/20], Loss: 1.3610
Epoch [13/20], Loss: 1.2847
Epoch [14/20], Loss: 1.5405
Epoch [15/20], Loss: 1.4908
Epoch [16/20], Loss: 1.4924
Epoch [17/20], Loss: 1.8065
Epoch [18/20], Loss: 1.5692
Epoch [19/20], Loss: 1.8273
Epoch [20/20], Loss: 1.6245
Class of Image 1: cat
Class of Image 2: dog
Euclidean Distance: 1.1389
The images are likely from different classes.
