In [1]:
import torch
import torch.nn as nn
import os
import numpy as np
import pathlib as path
import matplotlib.pyplot as plt
import copy
import torch.nn.functional as f
from tqdm import tqdm
from PIL import Image
from torch.optim import Adam 
from torch.nn import TripletMarginLoss
from torchvision import transforms
from torchinfo import summary
from torchvision.models import resnet18
from torchvision.models import ResNet18_Weights
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim.lr_scheduler import ReduceLROnPlateau
from time import time
from torch.nn.functional import pairwise_distance

In [None]:
class Earlystopper:
  """Patience-based early stopping that keeps a copy of the best weights.

  A validation loss counts as an improvement when it is at most
  ``best_validation - min_delta``. Every non-improving call increments
  ``counter``; once ``counter`` reaches ``patience`` the stop signal is raised.
  """

  def __init__(self, min_delta = 0.0, patience = 1):
    self.min_delta = min_delta          # required decrease to count as an improvement
    self.patience = patience            # non-improving calls tolerated before stopping
    self.best_validation = None         # lowest loss accepted so far
    self.best_model_state = None        # deep copy of the state_dict at that loss
    self.counter = 0                    # consecutive non-improving calls

  def earlystop(self, validation_loss, model):
    """Record ``validation_loss`` and report whether training should stop.

    Args:
      validation_loss: loss of the current epoch.
      model: object exposing ``state_dict()``; snapshotted on improvement.

    Returns:
      ``True`` once ``counter`` has reached ``patience``, otherwise ``False``.
    """
    first_call = self.best_validation is None

    if first_call or validation_loss <= self.best_validation - self.min_delta:
      self.best_validation = validation_loss
      self.best_model_state = copy.deepcopy(model.state_dict())
      # The very first observation only seeds the baseline; the counter is
      # reset on genuine improvements only.
      if not first_call:
        self.counter = 0
      return False

    self.counter += 1
    print(f"Earlystop: {self.counter}/{self.patience}")
    return self.counter >= self.patience

  def restore_best_weight(self, model):
    """Load the stored best state into ``model``; no-op if nothing was stored."""
    if self.best_model_state is None:
      return
    model.load_state_dict(self.best_model_state)
    print("Restored Best Weight")

In [None]:
def Accuracy(anchor, positive, negative, margin=0.2):
    """Fraction of triplets whose negative is farther than the positive by ``margin``.

    A triplet counts as correct when ``d(anchor, positive) + margin <
    d(anchor, negative)``, with ``d`` the Euclidean distance taken over dim 1.

    Args:
        anchor, positive, negative: tensors of identical shape; distances are
            reduced over ``dim=1`` (so at least 2-D inputs are required).
        margin: required gap between the two distances.

    Returns:
        Python float in ``[0, 1]``. An empty batch returns ``0.0`` instead of
        the NaN that the mean of an empty tensor would give, so running
        averages are never poisoned.
    """
    # torch.linalg.vector_norm replaces the deprecated torch.norm.
    d_ap = torch.linalg.vector_norm(anchor - positive, ord=2, dim=1)
    d_an = torch.linalg.vector_norm(anchor - negative, ord=2, dim=1)

    if d_ap.numel() == 0:
        return 0.0

    correct = (d_ap + margin < d_an).float()
    return correct.mean().item()


In [None]:
class SiameseDataset(Dataset):
  """Image-folder dataset yielding ``(image_tensor, class_index)`` pairs.

  ``image_dir`` is expected to contain one sub-directory per class; every
  ``.jpg`` / ``.jpeg`` / ``.png`` file inside a class directory becomes one
  sample labelled with that directory's index.

  Args:
    image_dir: root directory of the dataset. ``None`` or ``""`` yields an
      empty dataset.
    transform: callable applied to the loaded RGB ``PIL.Image``. When omitted,
      the augmentation pipeline from ``_default_transform`` is used.

  Items that cannot be loaded, or whose tensor contains NaN, are returned as
  ``(None, None)`` so that a collate function can drop them.
  """

  # Leading dots so that e.g. "notajpg" is not mistaken for an image.
  _IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

  def __init__(self, image_dir=None, transform = None):

    self.image_dir = image_dir
    # Honour a caller-supplied transform (it used to be silently overwritten).
    self.transform = transform if transform is not None else self._default_transform()
    self.image = self.prepare_triplet_data(self.image_dir)

  @staticmethod
  def _default_transform():
    """Build the default augmentation + tensor conversion pipeline."""
    return transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.5, 1.0), ratio=(0.9, 1.1)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(13),
        transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.02),
        transforms.RandomAffine(degrees=20, translate=(0.05, 0.05)),
        transforms.RandomGrayscale(p=0.05),
        transforms.ToTensor(),
        # NOTE(review): images are loaded as RGB, and the same 0.5/0.5 statistics
        # are given here for all channels. If the consumer is an
        # ImageNet-pretrained backbone, the ImageNet mean/std are presumably
        # the better match - confirm before changing, since it alters training.
        transforms.Normalize(mean=[0.5], std=[0.5]),
      ])

  @classmethod
  def prepare_triplet_data(cls, image_dir_path):
    """Return a list of ``(file_path, class_index)`` for every image found.

    Class directories and their files are visited in sorted order so labels
    and sample order are reproducible across runs and platforms. Class indices
    are contiguous over directories; stray files in the root are ignored.
    A falsy ``image_dir_path`` returns an empty list.
    """
    image_path = []
    if not image_dir_path:
      return image_path

    class_dirs = sorted(
      entry for entry in os.listdir(image_dir_path)
      if os.path.isdir(os.path.join(image_dir_path, entry))
    )

    for idx, folder in enumerate(class_dirs):
      class_dir = os.path.join(image_dir_path, folder)
      for file in sorted(os.listdir(class_dir)):
        if file.lower().endswith(cls._IMAGE_EXTENSIONS):
          image_path.append((os.path.join(class_dir, file), idx))

    return image_path

  def __len__(self):
    return len(self.image)

  def __getitem__(self, index):
    """Load, convert to RGB and transform the sample at ``index``.

    Returns ``(None, None)`` when the file cannot be opened/decoded or the
    transformed tensor contains NaN; the reason is printed either way.
    """
    path, label = self.image[index]
    try:
      # The context manager closes the file; convert() returns loaded data.
      with Image.open(path) as handle:
        image = handle.convert("RGB")
    except Exception as e:
      # Best effort: one corrupt file must not abort training, but say so.
      print(f"Could not load image at index {index} ({path}): {e}")
      return None, None

    image = self.transform(image)
    if torch.isnan(image).any():
      print(f"Nan Detected in image at index {index}")
      return None, None

    return image, label

In [None]:
# Root folder of the image dataset. SiameseDataset treats every sub-directory
# as one class. "DatasetPath" is presumably a placeholder - point it at the
# real dataset before running.
image_path = "DatasetPath"
# Index all images under the root; files are only opened lazily in __getitem__.
dataset = SiameseDataset(image_dir= image_path)

In [None]:
def collate_fn(batch):
  """Collate ``(image, label)`` samples, dropping those whose image is ``None``.

  Args:
    batch: iterable of ``(image_tensor_or_None, label)`` pairs.

  Returns:
    ``(images, labels)`` where ``images`` is the stacked image tensor and
    ``labels`` a 1-D tensor, or ``None`` when no valid sample is left.
  """
  batch = [item for item in batch if item[0] is not None]
  # A list is never None; the emptiness check is what guards zip(*batch),
  # which would otherwise fail to unpack when every sample was rejected.
  if not batch:
    return None
  image, label = zip(*batch)
  return torch.stack(image), torch.tensor(label)

In [None]:
# 70 / 30 train-test split.
train_size = int(0.7 * len(dataset))
test_size = len(dataset) - train_size

# Seed the split explicitly so the same samples land in each subset on every
# Restart & Run All.
generator = torch.Generator().manual_seed(42)

train_data, test_data = random_split(dataset, [train_size, test_size], generator=generator)

Train_loader = DataLoader(train_data, batch_size=64, shuffle= True, drop_last= True, collate_fn=collate_fn)
# Evaluation loader: fixed order and no dropped remainder, so every epoch is
# scored on the same samples and a test split smaller than one batch still
# yields a batch (drop_last=True would give zero batches there). random_split
# already permutes the indices, so batches still mix classes.
Test_loader = DataLoader(test_data, batch_size=64, shuffle=False, drop_last=False, collate_fn=collate_fn)


In [None]:
class Siamese_network(nn.Module):
  """Embedding network: frozen ResNet-18 trunk plus a trainable linear head.

  ``forward_once`` maps an image batch to L2-normalised 128-d embeddings;
  ``forward`` returns the pairwise distance between the embeddings of two
  batches.
  """

  def __init__(self):
    super().__init__()

    base_model = resnet18(weights = ResNet18_Weights.DEFAULT)

    # The pretrained trunk is used as a fixed feature extractor.
    for param in base_model.parameters():
      param.requires_grad = False

    # Every child except the last one (the classification layer).
    self.feature = nn.Sequential(*list(base_model.children())[:-1])

    self.fc = nn.Sequential(
      # Not in-place: avoids mutating the trunk's output tensor.
      nn.Dropout(0.2),
      nn.Linear(512, 128)
    )

  def train(self, mode = True):
    """Switch train/eval mode while keeping the frozen trunk in eval mode.

    Freezing parameters does not freeze BatchNorm running statistics; in
    training mode they would still be updated and batch statistics would be
    used, so the features seen in training and evaluation would differ.
    """
    super().train(mode)
    self.feature.eval()
    return self

  def forward(self, x1, x2):
    """Return the pairwise distance between the embeddings of ``x1`` and ``x2``."""
    x1 = self.forward_once(x1)
    x2 = self.forward_once(x2)

    return pairwise_distance(x1, x2)

  def forward_once(self, x):
    """Embed ``x`` and L2-normalise along dim 1."""
    x = torch.flatten(self.feature(x), 1)
    return f.normalize(self.fc(x), p=2, dim=1)

In [None]:
def semi_hard_batching(embedding, label, margin = 0.3):
  """Mine all semi-hard triplets inside a batch.

  A triplet ``(a, p, n)`` is selected when ``p`` shares ``a``'s label
  (``p != a``), ``n`` has a different label, and
  ``d(a, p) < d(a, n) < d(a, p) + margin`` with Euclidean ``d``.

  Args:
    embedding: 2-D tensor, one row per sample.
    label: 1-D tensor of class ids, one per row of ``embedding``.
    margin: width of the semi-hard band.

  Returns:
    ``(anchor, positive, negative)``: three 1-D ``torch.long`` index tensors
    of equal length on ``embedding``'s device (all empty when nothing
    qualifies). Triplets are ordered by anchor, then negative, then positive.

  Note:
    Builds a boolean tensor of ``batch_size ** 3`` elements, which is small
    for typical batch sizes but grows cubically.
  """
  device = embedding.device
  label = label.to(device)

  # Mining only selects indices, so keep it out of the autograd graph.
  detached = embedding.detach()
  dist = torch.cdist(detached, detached, p=2)

  same_class = label.unsqueeze(0) == label.unsqueeze(1)
  not_self = ~torch.eye(embedding.size(0), dtype=torch.bool, device=device)
  positive_mask = same_class & not_self    # [anchor, positive]
  negative_mask = ~same_class              # [anchor, negative]

  d_an = dist.unsqueeze(2)                 # [anchor, negative, 1]
  d_ap = dist.unsqueeze(1)                 # [anchor, 1, positive]
  in_band = (d_an > d_ap) & (d_an < d_ap + margin)

  # Axis order [anchor, negative, positive] fixes the enumeration order.
  valid = in_band & negative_mask.unsqueeze(2) & positive_mask.unsqueeze(1)

  anchor, negative, positive = torch.nonzero(valid, as_tuple=True)
  return anchor, positive, negative

In [None]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'

# Upper bound on the number of epochs; early stopping may end training sooner.
Epochs = 50

model = Siamese_network()
model = model.to(device)

# Triplet loss over (anchor, positive, negative) embeddings.
criterion = TripletMarginLoss(margin= 0.5)
optimizer = Adam(model.parameters(), lr = 1e-5)

# The learning rate is multiplied by 0.3 after 5 epochs without a lower
# monitored loss; training stops after 10 epochs without a 0.01 improvement.
scheduler = ReduceLROnPlateau(optimizer, mode= 'min', factor= 0.3, patience= 5)
earlystop = Earlystopper(min_delta= 0.01, patience= 10)

In [None]:
# Per-epoch history: *_A = training, *_B = validation.
acc_stat_A = []
acc_stat_B = []

loss_stat_A = []
loss_stat_B = []


def _run_epoch(loader, training, desc):
  """Run one pass over ``loader`` and return ``(mean_loss, mean_acc)``.

  Uses the notebook globals ``model``, ``criterion``, ``optimizer`` and
  ``device``. With ``training=True`` the model is put in train mode and the
  optimizer is stepped on every batch; otherwise the model is put in eval
  mode and the pass runs under ``torch.inference_mode``.

  Batches that are ``None`` or yield no semi-hard triplets are skipped, and
  the means are taken over the batches that actually contributed. Returns
  ``None`` when no batch contributed.
  """
  model.train(training)

  loss_sum = 0.0
  acc_sum = 0.0
  used_batches = 0

  # inference_mode(mode=False) is a no-op, so gradients flow during training.
  with torch.inference_mode(mode=not training):
    for batch in tqdm(loader, total= len(loader), desc = desc, leave = False):

      # A collate function may signal "nothing usable" with None.
      if batch is None:
        continue

      img, label = [data.to(device) for data in batch]

      embedding = model.forward_once(img)

      a, p, n = semi_hard_batching(embedding, label)

      if a.numel() == 0:
        continue

      a = embedding[a]
      p = embedding[p]
      n = embedding[n]

      loss = criterion(a, p, n)

      loss_sum += loss.item()
      acc_sum += Accuracy(a, p, n)
      used_batches += 1

      if training:
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

  if used_batches == 0:
    return None

  return loss_sum / used_batches, acc_sum / used_batches


for Epoch in range(Epochs):

  train_result = _run_epoch(Train_loader, True, f"Epoch: {Epoch + 1} / {Epochs} - Training:")
  test_result = _run_epoch(Test_loader, False, f"Epoch: {Epoch + 1} / {Epochs} - Testing")

  # NaN marks an epoch in which no triplet could be mined.
  missing = (float("nan"), float("nan"))
  Train_loss, Train_acc = train_result if train_result is not None else missing
  Test_loss, Test_acc = test_result if test_result is not None else missing

  loss_stat_A.append(Train_loss)
  acc_stat_A.append(Train_acc)
  loss_stat_B.append(Test_loss)
  # Epoch mean, not the accuracy of whichever batch happened to run last.
  acc_stat_B.append(Test_acc)

  # Logged before the early-stopping check so the final epoch is reported too.
  # Accuracies are fractions (shown as percentages); losses are plain values.
  print(f"Epoch: {Epoch+1}\t|Train Acc: {Train_acc:.2%}\t|Train Loss: {Train_loss:.4f}\t|Val Acc: {Test_acc:.2%}\t|Val Loss: {Test_loss:.4f}")

  if test_result is None:
    # Without a validation loss there is nothing meaningful to feed the
    # scheduler or the early stopper.
    print("No validation triplets mined this epoch; skipping scheduler and early-stopping update.")
    continue

  scheduler.step(Test_loss)

  if earlystop.earlystop(Test_loss, model):
    earlystop.restore_best_weight(model)
    break

In [None]:
def visualize_distance(result,device):
  """Plot the four training curves on a 2x2 grid.

  Args:
    result: indexable with four series, in the order training accuracy,
      validation accuracy, training loss, validation loss.
    device: unused; kept so existing calls keep working.
  """
  titles = ("Accuracy", "Valid Accuracy", "Loss", "Valid Loss")

  plt.figure(figsize= (9, 3))

  # Subplot slots are 1-based, the series are 0-based.
  for slot, title in enumerate(titles, start=1):
    plt.subplot(2, 2, slot)
    plt.title(title)
    plt.plot(result[slot - 1])

  plt.tight_layout()
  plt.show()

visualize_distance((acc_stat_A, acc_stat_B, loss_stat_A, loss_stat_B), device)
  

In [None]:
def save_model(model, model_path="SavePath", model_name="modelName.pth"):
  """Save ``model.state_dict()`` to ``model_path/model_name``.

  Args:
    model: object exposing ``state_dict()``.
    model_path: destination directory; created if it does not exist.
    model_name: file name of the checkpoint.

  Returns:
    The full path the state dict was written to.
  """
  # torch.save does not create missing directories.
  if model_path:
    os.makedirs(model_path, exist_ok=True)

  model_save_path = os.path.join(model_path, model_name)

  torch.save(model.state_dict(), model_save_path)
  return model_save_path


# Persist the weights currently held by `model`. If early stopping fired in
# the training cell, these are the restored best weights; otherwise they are
# the weights after the last epoch.
save_model(model)