In [None]:
import torch
import random
import pandas as pd
import torch.nn as nn
from PIL import Image
from tqdm import tqdm
import torch.nn.functional as F
from torchvision import transforms
from efficientnet_pytorch import EfficientNet
from torch.utils.data import DataLoader, Dataset

# Training-time augmentation: random 456-px resized crop, horizontal flip,
# small rotation and colour jitter, then conversion to a tensor normalized with
# the ImageNet channel mean/std.
transform = transforms.Compose(
    [
        transforms.RandomResizedCrop(size=456),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Deterministic preprocessing (no random ops) used for the non-augmented image
# in each dataset triple: resize, take the central 456x456 crop, and normalize
# with the same statistics as `transform`.
basic_transform = transforms.Compose(
    [
        transforms.Resize(size=456),
        transforms.CenterCrop(size=456),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)


def load_and_preprocess_image(path, preprocess=None):
    """Load an image file as RGB and run it through a preprocessing pipeline.

    Args:
        path: Path to the image file.
        preprocess: Callable applied to the RGB PIL image. Defaults to the
            module-level augmenting ``transform`` (random crop/flip/rotation/
            colour jitter), so repeated calls on the same file give different
            results; pass ``basic_transform`` for deterministic output.

    Returns:
        Whatever ``preprocess`` returns (a normalized tensor for the
        ``transform`` / ``basic_transform`` pipelines defined in this cell).
    """
    if preprocess is None:
        preprocess = transform
    # The context manager closes the underlying file even for multi-frame
    # formats, which Pillow would otherwise keep open after convert().
    with Image.open(path) as image:
        rgb = image.convert('RGB')
    return preprocess(rgb)

In [None]:
import torchvision.models as models


class EfficientNetEmbedding(nn.Module):
    """Embed RGB images with a pretrained torchvision EfficientNet-B5 and an MLP head.

    The backbone's ImageNet classification head is replaced by ``nn.Identity``,
    so ``self.base_model(x)`` returns one globally pooled feature vector per
    image. ``fc1 -> ReLU -> Dropout -> fc2`` then projects it to
    ``embedding_dim``.

    Args:
        embedding_dim: Size of the output embedding (e.g. 256, 512, 1024, 2048).
    """

    def __init__(self, embedding_dim=256):
        super().__init__()
        # torchvision expects a weights enum *member*, not the enum class, and
        # its builder has no ``include_top`` argument (that is a Keras option).
        self.base_model = models.efficientnet_b5(weights=models.EfficientNet_B5_Weights.DEFAULT)
        # torchvision's head is ``classifier = Sequential(Dropout, Linear)``.
        # The ``_fc`` / ``extract_features`` names belong to the separate
        # efficientnet_pytorch package and do not exist on this model.
        in_features = self.base_model.classifier[-1].in_features
        # Drop the ImageNet classifier so the backbone outputs pooled features.
        self.base_model.classifier = nn.Identity()
        self.fc1 = nn.Linear(in_features, 512)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(512, embedding_dim)

    def forward(self, x):
        """Return embeddings of shape (batch, embedding_dim) for the image batch ``x``."""
        # Backbone forward = features -> avgpool -> flatten -> (Identity) head,
        # giving (batch, in_features). Pooling is required: flattening the raw
        # feature map would not match fc1's input size.
        x = self.base_model(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x


class SiameseNetwork(nn.Module):
    """Twin-branch network that embeds two inputs with one shared encoder.

    Both inputs go through the same ``EfficientNetEmbedding`` instance, so the
    two branches share weights.

    Args:
        embedding_dim: Output size of each embedding.
    """

    def __init__(self, embedding_dim=128):
        super().__init__()
        self.embedding_net = EfficientNetEmbedding(embedding_dim)

    def forward(self, x1, x2):
        """Embed ``x1`` then ``x2`` with the shared encoder; return both as a tuple."""
        # Two separate encoder calls, in order (x1 first), exactly one per input.
        return tuple(self.embedding_net(batch) for batch in (x1, x2))


class NTXentLoss(nn.Module):
    """Normalized temperature-scaled cross-entropy (NT-Xent) loss, as in SimCLR.

    Row ``k`` of ``z_i`` and row ``k`` of ``z_j`` form a positive pair. Every
    other row of the concatenated ``2N`` batch acts as a negative.

    Args:
        temperature: Divisor applied to the cosine similarities before softmax.
    """

    def __init__(self, temperature=0.5):
        super().__init__()
        self.temperature = temperature
        self.cosine_similarity = nn.CosineSimilarity(dim=-1)

    def forward(self, z_i, z_j):
        """Return the mean NT-Xent loss over all ``2N`` anchors.

        Args:
            z_i: Embeddings of the first view, shape (N, D).
            z_j: Embeddings of the second view, shape (N, D); row k pairs with z_i[k].

        Raises:
            ValueError: If ``z_i`` and ``z_j`` have different batch sizes.
        """
        if z_i.size(0) != z_j.size(0):
            raise ValueError(
                f'z_i and z_j must have the same batch size, got {z_i.size(0)} and {z_j.size(0)}'
            )
        batch_size = z_i.size(0)
        z = torch.cat([z_i, z_j], dim=0)  # (2N, D)
        # Pairwise cosine similarity via broadcasting: (2N, 1, D) vs (1, 2N, D) -> (2N, 2N).
        sim_matrix = self.cosine_similarity(z.unsqueeze(1), z.unsqueeze(0)) / self.temperature
        # Each embedding is maximally similar to itself. Exclude the diagonal so
        # it is neither the target nor a (trivially easy) competing logit.
        self_mask = torch.eye(2 * batch_size, dtype=torch.bool, device=z.device)
        sim_matrix = sim_matrix.masked_fill(self_mask, float('-inf'))
        # The positive for anchor k (first view) is k + N, and for anchor N + k
        # (second view) it is k.
        idx = torch.arange(batch_size, device=z.device)
        labels = torch.cat([idx + batch_size, idx], dim=0)
        return F.cross_entropy(sim_matrix, labels)


class SiameseDataset(Dataset):
    """Dataset yielding two augmented views of one image plus a different image.

    Item ``idx`` is the triple ``(view_1, view_2, other)``:

    - ``view_1`` and ``view_2`` are two independent ``transform`` draws of image ``idx``.
    - ``other`` is a uniformly chosen *different* image passed through
      ``basic_transform``.

    Args:
        image_paths: Sequence of image file paths.
        transform: Augmentation applied (twice) to the anchor image. If None,
            the RGB PIL image is returned unchanged for both views.
        basic_transform: Transform applied to the other image. If None, the
            RGB PIL image is returned unchanged.
    """

    def __init__(self, image_paths, transform=None, basic_transform=None):
        self.image_paths = image_paths
        self.transform = transform
        self.basic_transform = basic_transform

    def __len__(self):
        return len(self.image_paths)

    @staticmethod
    def _load_rgb(path):
        """Open ``path``, return an RGB copy, and close the file handle."""
        with Image.open(path) as img:
            return img.convert('RGB')

    @staticmethod
    def _sample_negative_index(idx, n):
        """Return an index drawn uniformly from ``range(n)``, excluding ``idx``.

        Negative ``idx`` values are interpreted Python-style (``-1`` is ``n - 1``).

        Raises:
            ValueError: If ``n < 2``, since no other image exists.
        """
        if n < 2:
            raise ValueError('SiameseDataset needs at least two images to sample a negative')
        idx %= n
        # Draw from the n - 1 other slots and shift past ``idx``. Unlike the
        # rejection loop, this cannot spin forever.
        neg_idx = random.randrange(n - 1)
        return neg_idx + 1 if neg_idx >= idx else neg_idx

    def __getitem__(self, idx):
        anchor = self._load_rgb(self.image_paths[idx])

        # Positive pair: two independent random augmentations of the same image.
        if self.transform is not None:
            img1_aug = self.transform(anchor)
            img2_aug = self.transform(anchor)
        else:
            img1_aug = img2_aug = anchor

        # A different image, chosen uniformly among the rest.
        neg_idx = self._sample_negative_index(idx, len(self.image_paths))
        other = self._load_rgb(self.image_paths[neg_idx])
        if self.basic_transform is not None:
            other = self.basic_transform(other)

        return img1_aug, img2_aug, other

In [None]:
def train_contrastive_model(dataloader, epochs=30, temperature=0.5, learning_rate=1e-3):
    """Train a fresh SiameseNetwork with the NT-Xent contrastive objective.

    Each batch from ``dataloader`` is unpacked as a triple
    ``(view_1, view_2, other)`` (the layout produced by ``SiameseDataset``).
    The two views of the same image form the positive pair. NT-Xent already
    treats every other image in the batch as a negative, so the extra ``other``
    image is ignored. Feeding (anchor, other) into NT-Xent would label it a
    *positive* pair and pull unrelated images together.

    Args:
        dataloader: Iterable of batches with a ``len()``.
        epochs: Number of passes over ``dataloader``.
        temperature: NT-Xent temperature.
        learning_rate: Initial Adam learning rate. It is reduced 10x when the
            epoch loss plateaus for 5 epochs.

    Returns:
        The trained SiameseNetwork (left on the training device).

    Raises:
        ValueError: If ``dataloader`` has no batches.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError('dataloader is empty; nothing to train on')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = SiameseNetwork().to(device)
    criterion = NTXentLoss(temperature)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    # ``verbose=`` is deprecated in recent PyTorch; LR drops are reported below instead.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        prog_bar = tqdm(dataloader)

        # Count steps ourselves: tqdm's ``n`` only syncs on display refresh.
        for step, (view_1, view_2, _other) in enumerate(prog_bar, start=1):
            view_1, view_2 = view_1.to(device), view_2.to(device)

            optimizer.zero_grad()
            emb_1, emb_2 = model(view_1, view_2)
            loss = criterion(emb_1, emb_2)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            prog_bar.set_description(f'Epoch {epoch + 1}/{epochs} Train Loss: {train_loss / step:.4f}')

        avg_train_loss = train_loss / num_batches
        print(f'Epoch {epoch + 1} Train Loss: {avg_train_loss:.4f}')

        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(avg_train_loss)
        lr_after = optimizer.param_groups[0]['lr']
        if lr_after < lr_before:
            print(f'Epoch {epoch + 1}: reducing learning rate to {lr_after:.4e}')

    return model

In [None]:
# Only the `file_path` column is used; self-supervised training needs no labels.
df = pd.read_csv('datasets/cropped_all_one_hot.csv')
image_paths = df.loc[:, 'file_path'].to_list()

In [None]:
batch_size = 128
# Each item is (view_1, view_2, other): two augmented views plus a deterministically preprocessed other image.
dataset = SiameseDataset(image_paths, basic_transform=basic_transform, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [None]:
trained_model = train_contrastive_model(
    dataloader,
    epochs=50,
    temperature=0.5,
    learning_rate=1e-3,
)