In [2]:
import os
import pandas as pd
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms

# Dataset that yields matched left/right image pairs
class RetinaDataset(Dataset):
    """Dataset of paired left/right images listed in a CSV file.

    Each CSV row describes one pair: column 0 holds the file stem of the left
    image and column 1 the file stem of the right image. Images are read from
    ``<root_dir>/left`` and ``<root_dir>/right`` respectively.

    Args:
        csv_file: Path of the CSV file, read with ``pandas.read_csv``.
        root_dir: Directory containing the ``left`` and ``right`` sub-folders.
        transform: Optional callable applied to each image of the pair.
        extension: Suffix appended to every file stem. Defaults to ``'.jpg'``,
            which was previously hard-coded.
    """

    def __init__(self, csv_file, root_dir, transform=None, extension='.jpg'):
        self.annotations = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform
        self.extension = extension

    def __len__(self):
        """Return the number of pairs (one per CSV row)."""
        return len(self.annotations)

    def _image_path(self, side, index, column):
        """Build ``<root_dir>/<side>/<stem><extension>`` for one CSV cell."""
        stem = self.annotations.iloc[index, column]
        # Formatting (instead of ``+``) also accepts non-string stems such as
        # integer IDs parsed by pandas.
        return os.path.join(self.root_dir, side, f"{stem}{self.extension}")

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return an RGB copy, closing the file afterwards."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, index):
        """Return the ``(left_image, right_image)`` pair at row ``index``.

        Both images are converted to RGB and, when a transform was supplied,
        passed through it independently.
        """
        left_image = self._load_rgb(self._image_path('left', index, 0))
        right_image = self._load_rgb(self._image_path('right', index, 1))

        if self.transform is not None:
            left_image = self.transform(left_image)
            right_image = self.transform(right_image)

        return (left_image, right_image)

# Example usage: resize to 224x224, convert to a tensor, then normalize per channel
transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

dataset = RetinaDataset('train.csv', 'train', transform=transform)

# Smoke test: fetch the first pair to confirm the dataset can be indexed
sample_left_image, sample_right_image = dataset[0]


In [3]:
import torchvision.models as models

# Fetch the pretrained ResNet-18 and write its state dict to disk so that
# later cells can load the weights from 'resnet18_weights.pth'
resnet18 = models.resnet18(pretrained=True)
torch.save(resnet18.state_dict(), f='resnet18_weights.pth')




In [4]:
import torch.nn as nn
import torchvision.models as models

# Define the Siamese Network again
class SiameseNetwork(nn.Module):
    """Twin network that passes both inputs through one shared ResNet-18."""

    def __init__(self):
        super().__init__()
        # Built without pretrained weights; a saved state dict is loaded
        # into ``self.resnet18`` after construction.
        self.resnet18 = models.resnet18(pretrained=False)

    def forward_one(self, x):
        """Run a single input batch through the shared backbone."""
        return self.resnet18(x)

    def forward(self, input1, input2):
        """Return the backbone outputs for ``input1`` and ``input2`` as a tuple."""
        return self.forward_one(input1), self.forward_one(input2)

# Location of the ResNet-18 state dict saved by an earlier cell
weights_path = "resnet18_weights.pth"

# Build the Siamese network and populate its backbone from the saved weights
siamese_model = SiameseNetwork()
siamese_model.resnet18.load_state_dict(torch.load(weights_path))

# Last expression of the cell: shows the model architecture
siamese_model



SiameseNetwork(
  (resnet18): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, t

In [5]:
# Device used by later cells when placing the model, the loss and the labels.
device = torch.device("cpu")


In [8]:
# 1. Define the model and the optimizer
siamese_network = SiameseNetwork()
siamese_network.resnet18.load_state_dict(torch.load('resnet18_weights.pth'))
siamese_network = siamese_network.to(device)

# NOTE: the loss is deliberately not instantiated in this cell. ContrastiveLoss
# is defined in a later cell, so creating it here raises NameError on a fresh
# "Restart & Run All". `criterion` is created right before training instead.
optimizer = torch.optim.Adam(siamese_network.parameters(), lr=0.001)



In [7]:
import torch.nn.functional as F

class ContrastiveLoss(nn.Module):
    """Contrastive loss over the pairwise distance of two batches of outputs.

    Per sample, with distance ``d`` and label ``y``:
    ``(1 - y) * d**2 + y * clamp(margin - d, min=0)**2``; the batch mean is
    returned. A label of 0 therefore penalizes distance, while a label of 1
    penalizes pairs closer than ``margin``.
    """

    def __init__(self, margin=2.0):
        super().__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        """Return the mean contrastive loss for the given outputs and labels."""
        distance = F.pairwise_distance(output1, output2)
        # Term active for label 0: squared distance
        pull_term = (1 - label) * distance.pow(2)
        # Term active for label 1: squared shortfall below the margin
        shortfall = torch.clamp(self.margin - distance, min=0.0)
        push_term = label * shortfall.pow(2)
        return (pull_term + push_term).mean()


In [9]:
from torchvision import transforms

# Preprocessing applied to every image: resize, tensor conversion, per-channel normalization
transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# Dataset of image pairs listed in train.csv, read from the 'train' directory
train_dataset = RetinaDataset('train.csv', 'train', transform=transform)


In [11]:
from torch.utils.data import DataLoader, random_split

# Fixed seed so the 80/20 train/validation split is the same on every run
SPLIT_SEED = 42

train_size = int(0.8 * len(train_dataset))
# The validation subset takes whatever remains after rounding train_size down
val_size = len(train_dataset) - train_size
train_subset, val_subset = random_split(
    train_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = DataLoader(train_subset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_subset, batch_size=32, shuffle=False)


Collecting tqdm
  Obtaining dependency information for tqdm from https://files.pythonhosted.org/packages/00/e5/f12a80907d0884e6dff9c16d0c0114d81b8cd07dc3ae54c5e962cc83037e/tqdm-4.66.1-py3-none-any.whl.metadata
  Using cached tqdm-4.66.1-py3-none-any.whl.metadata (57 kB)
Using cached tqdm-4.66.1-py3-none-any.whl (78 kB)
Installing collected packages: tqdm
Successfully installed tqdm-4.66.1
Note: you may need to restart the kernel to use updated packages.


In [12]:
from tqdm import tqdm

def train_siamese_network(model, criterion, optimizer, train_loader, val_loader, epochs=10, device=None):
    """Train ``model`` on batches of matched image pairs.

    Every batch yielded by ``train_loader`` is a ``(left_img, right_img)``
    pair that is treated as a *similar* pair. The labels passed to
    ``criterion`` are therefore all zeros: in this notebook's
    ``ContrastiveLoss`` the ``(1 - label)`` factor multiplies the squared
    distance, so label 0 pulls a pair together, whereas label 1 would activate
    the margin term and push it apart.

    Args:
        model: Module called as ``model(left_img, right_img)`` that returns
            two outputs.
        criterion: Loss called as ``criterion(output1, output2, labels)``.
        optimizer: Optimizer that updates ``model``'s parameters.
        train_loader: Iterable of ``(left_img, right_img)`` batches; must
            support ``len()``.
        val_loader: Accepted for interface compatibility; currently unused.
        epochs: Number of passes over ``train_loader``.
        device: Device on which inputs and labels are placed. When ``None``,
            the device of the model's first parameter is used (CPU if the
            model has no parameters).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        # Wrap the train_loader with tqdm for the progress bar
        for left_img, right_img in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            # Inputs must live on the same device as the model and the labels
            left_img = left_img.to(device)
            right_img = right_img.to(device)

            optimizer.zero_grad()

            output1, output2 = model(left_img, right_img)
            # All pairs are assumed to come from the same person (similar),
            # which corresponds to label 0 for the contrastive loss.
            labels = torch.zeros(left_img.size(0), device=device)
            loss = criterion(output1, output2, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # Guard against an empty loader so the average never divides by zero
        num_batches = max(len(train_loader), 1)
        print(f"Loss: {running_loss/num_batches:.4f}")

        # TODO: a validation pass over val_loader can be added here

# Re-create the loss and the optimizer immediately before training
criterion = ContrastiveLoss().to(device)
optimizer = torch.optim.Adam(params=siamese_network.parameters(), lr=1e-3)

# Run the training loop for 10 epochs
train_siamese_network(
    model=siamese_network,
    criterion=criterion,
    optimizer=optimizer,
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=10,
)



Epoch 1/10:   2%|▏         | 1/50 [00:18<15:24, 18.88s/it]