In [1]:
import os
import sys
import random
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2


import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


import torchvision
import torchvision.utils
import torchvision.models as models
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset

import torchsummary
from vi_data import plot_image
from CustomLosses import TripletLossTorch

# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [2]:
# Dataset roots, read with torchvision's ImageFolder (one sub-directory per class).
Train_dir = "./Train/"
Test_dir = "./Test"

# Seed the Python, NumPy and PyTorch RNGs so triplet sampling, augmentation,
# DataLoader shuffling and weight initialisation are reproducible across runs.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
class ModelDataset(Dataset):
    """Dataset of (anchor, positive, negative) image triplets for metric learning.

    Args:
        imgs: sequence of ``(path, label)`` pairs, e.g. ``ImageFolder.imgs``.
        transform: callable applied to each RGB PIL image; defaults to
            ``transforms.ToTensor()`` when ``None``.

    Item ``index`` uses ``imgs[index]`` as the anchor. The positive is a random
    *other* image with the same label; the anchor itself is used only when it is
    the sole image of its label. The negative is drawn uniformly from all images
    whose label differs. Positives and negatives are re-drawn on every access.
    """

    def __init__(self, imgs, transform=None):
        self.imgs = imgs
        self.transform = transform if transform is not None else transforms.ToTensor()
        # Group dataset indices by label once so __getitem__ does not rescan the
        # whole dataset for every item (which made a single epoch O(n^2)).
        self._indices_by_label = {}
        for idx, (_, label) in enumerate(imgs):
            self._indices_by_label.setdefault(label, []).append(idx)

    def _sample_triplet_paths(self, index):
        """Return ``(anchor_path, positive_path, negative_path)`` for item ``index``.

        Raises:
            ValueError: if every image shares the anchor's label (no negative exists).
        """
        if index < 0:
            index += len(self.imgs)
        anchor_path, anchor_label = self.imgs[index]

        same_label = self._indices_by_label[anchor_label]
        other_positives = [i for i in same_label if i != index]
        # Using the anchor as its own positive yields a zero distance under a
        # deterministic transform, so only fall back to it for singleton labels.
        positive_idx = random.choice(other_positives) if other_positives else index

        negative_labels = [lbl for lbl in self._indices_by_label if lbl != anchor_label]
        if not negative_labels:
            raise ValueError(
                f"No negative available for label {anchor_label!r}: "
                "the dataset needs images from at least two labels."
            )
        # Choose a label weighted by its image count, then an image within it:
        # overall this is a uniform draw over all differently-labelled images.
        weights = [len(self._indices_by_label[lbl]) for lbl in negative_labels]
        negative_label = random.choices(negative_labels, weights=weights)[0]
        negative_idx = random.choice(self._indices_by_label[negative_label])

        return anchor_path, self.imgs[positive_idx][0], self.imgs[negative_idx][0]

    def __getitem__(self, index):
        triplet = []
        for path in self._sample_triplet_paths(index):
            # The context manager closes the file handle PIL keeps open lazily;
            # convert('RGB') loads the pixels and forces three channels.
            with Image.open(path) as raw:
                img = raw.convert('RGB')
            if self.transform is not None:
                img = self.transform(img)
            triplet.append(img)
        anc_img, pos_img, neg_img = triplet
        return anc_img, pos_img, neg_img

    def __len__(self):
        return len(self.imgs)

In [4]:
# Training pipeline: resize, then random horizontal flips and rotations (up to
# +/-90 degrees) as augmentation, then convert to a tensor.
train_tfms = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(90),
    transforms.ToTensor(),
])

# ImageFolder only supplies (path, class_index) pairs; ModelDataset turns them
# into triplets. Separate names keep the two objects from shadowing each other.
train_folder = datasets.ImageFolder(Train_dir)
train_ds = ModelDataset(train_folder.imgs, transform=train_tfms)
train_dl = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=0)

In [5]:
# Evaluation pipeline: deterministic resize only, no augmentation.
test_tfms = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])
test_folder = datasets.ImageFolder(Test_dir)
test_ds = ModelDataset(test_folder.imgs, transform=test_tfms)
# This loader is used for validation, where the mean loss does not depend on
# batch order, so there is no need to shuffle it.
test_dl = DataLoader(test_ds, batch_size=16, shuffle=False, num_workers=0)

In [6]:
class ResNET(torch.nn.Module):
    """Small residual CNN mapping RGB image batches to 96-dimensional embeddings.

    Five ``conv_layer`` stages each shrink the spatial size; ``res_layer`` blocks
    between them preserve the shape and are added back as residuals. For 128x128
    inputs the last stage already outputs a 256x1x1 map, so the global average
    pool in ``forward`` is an identity there; for larger inputs it collapses the
    spatial grid so the 256-unit fully connected head still fits.
    """

    def __init__(self):
        super().__init__()
        self.conv_layer1 = self.conv_layer(3, 16, 7, stride=2)
        self.conv_layer2 = self.conv_layer(16, 32, 3)
        self.res_layer1 = self.res_layer(32)
        self.conv_layer3 = self.conv_layer(32, 64, 3)
        self.res_layer2 = self.res_layer(64)
        self.conv_layer4 = self.conv_layer(64, 128, 3)
        self.res_layer3 = self.res_layer(128)
        self.conv_layer5 = self.conv_layer(128, 256, 3)

        self.fc_layer1 = self.fully_connected(256, 192, dropout=True)
        self.fc_layer2 = self.fully_connected(192, 128)
        self.last_layer = self.last_fc(128, 96)

    def conv_layer(self, i_filter, o_filter, kernel, stride=1):
        """Unpadded conv(kernel, stride) -> BatchNorm -> LeakyReLU -> conv 3x3/stride 2/pad 1.

        The trailing strided conv roughly halves the spatial size (ceil(n / 2)).
        """
        layer = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=i_filter, out_channels=o_filter, kernel_size=kernel, stride=stride),
            torch.nn.BatchNorm2d(o_filter),
            torch.nn.LeakyReLU(),
            torch.nn.Conv2d(in_channels=o_filter, out_channels=o_filter, kernel_size=3, padding=1, stride=2)
        )
        return layer

    def res_layer(self, i_filter, kernel=3, pad=1):
        """conv(kernel, pad) -> BN -> LeakyReLU -> 1x1 conv -> BN -> LeakyReLU.

        With the defaults (kernel=3, pad=1) channels and spatial size are preserved,
        so the output can be added back onto the input.
        """
        layer = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=i_filter, out_channels=i_filter, kernel_size=kernel, padding=pad),
            torch.nn.BatchNorm2d(i_filter),
            torch.nn.LeakyReLU(),
            torch.nn.Conv2d(in_channels=i_filter, out_channels=i_filter, kernel_size=1),
            torch.nn.BatchNorm2d(i_filter),
            torch.nn.LeakyReLU(),
        )
        return layer

    def fully_connected(self, input_n, output_n, dropout=False):
        """[Dropout(p=0.2)] -> Linear -> LeakyReLU; dropout is prepended only if requested."""
        if dropout:
            layer = torch.nn.Sequential(
                torch.nn.Dropout(p=0.2),
                torch.nn.Linear(in_features=input_n, out_features=output_n),
                torch.nn.LeakyReLU()
            )
        else:
            layer = torch.nn.Sequential(
                torch.nn.Linear(in_features=input_n, out_features=output_n),
                torch.nn.LeakyReLU()
            )
        return layer

    def last_fc(self, input_n, output_n):
        """Plain linear projection (no activation) producing the final embedding."""
        layer = torch.nn.Sequential(
            torch.nn.Linear(in_features=input_n, out_features=output_n),
        )
        return layer

    def forward(self, inputs):
        """Embed a ``(batch, 3, H, W)`` image tensor into ``(batch, 96)`` vectors."""
        out = self.conv_layer1(inputs)
        out = self.conv_layer2(out)
        out = out + self.res_layer1(out)

        out = self.conv_layer3(out)
        out = out + self.res_layer2(out)

        out = self.conv_layer4(out)
        out = out + self.res_layer3(out)

        out = self.conv_layer5(out)
        # Global average pool to 1x1: an identity for 128x128 inputs (the map is
        # already 1x1), and it keeps the head's 256 input features valid for
        # larger images instead of relying on a size-specific flatten.
        out = F.adaptive_avg_pool2d(out, 1).flatten(1)

        out = self.fc_layer1(out)
        out = self.fc_layer2(out)
        out = self.last_layer(out)
        return out


In [7]:
# Embedding network and triplet loss, both moved onto `device`.
model = ResNET().to(device)
criterion = TripletLossTorch().to(device)
optimizer = optim.Adam(params=model.parameters(), lr=2e-4)
# Halves the learning rate every 5 calls to scheduler.step().
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

In [None]:
# Per-layer output shapes and parameter counts for one (channels, height, width)
# image, matching the 3-channel 128x128 tensors produced by the transforms above.
summary_input_shape = (3, 128, 128)
torchsummary.summary(model, summary_input_shape)

In [None]:
import time
class TrainModel:
    """Epoch loop for a triplet-loss embedding model.

    Every batch yielded by the loaders must unpack as ``(anchor, positive, negative)``
    tensors; all three go through the same ``model`` and are scored with
    ``criterion(anchor_out, positive_out, negative_out)``.

    Args:
        model: network producing embeddings; it is moved to ``device``.
        train_dl: training DataLoader; ``len(train_dl.dataset)`` averages the loss.
        criterion: triplet loss callable.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: optional LR scheduler, stepped once at the end of every epoch.
        device: device (or device string) to run on. A CUDA request falls back to
            the CPU when CUDA is unavailable, so the default works on CPU-only hosts.
        valid_dl: optional validation DataLoader, evaluated each epoch without gradients.
        metric: optional object exposing ``compute()``; its value is recorded per phase.
            NOTE(review): nothing in this class feeds batches to the metric — confirm
            how it is meant to be updated.
    """

    def __init__(self, model, train_dl, criterion, optimizer, scheduler=None, device=torch.device("cuda"), valid_dl=None, metric=None):
        device = torch.device(device)
        if device.type == "cuda" and not torch.cuda.is_available():
            device = torch.device("cpu")
        self.device = device
        self.model = model.to(device)
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.dataloaders = {"Train": train_dl}
        self.phases = ["Train"]
        self.loss_history = []
        self.metric = metric
        if metric is not None:
            self.metric_history = []
        if valid_dl is not None:
            self.dataloaders["Valid"] = valid_dl
            self.phases.append("Valid")
            self.val_loss_history = []
            if metric is not None:
                self.val_metric_history = []

    def _run_phase(self, phase):
        """Run one pass over the ``phase`` loader and return its per-sample mean loss.

        "Train" updates the weights; any other phase runs in eval mode with
        gradients disabled.
        """
        is_train = phase == "Train"
        self.model.train(is_train)
        loader = self.dataloaders[phase]
        running_loss = 0.0
        with torch.set_grad_enabled(is_train):
            for anchor, positive, negative in loader:
                anchor = anchor.to(self.device)
                positive = positive.to(self.device)
                negative = negative.to(self.device)
                anchor_out = self.model(anchor)
                positive_out = self.model(positive)
                negative_out = self.model(negative)
                loss = self.criterion(anchor_out, positive_out, negative_out)
                if is_train:
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
                # Weight by batch size so the epoch figure is a per-sample mean
                # (assumes criterion returns a per-batch mean — confirm).
                running_loss += loss.item() * anchor.size(0)
        return running_loss / len(loader.dataset)

    def train_step(self, epochs):
        """Run ``epochs`` epochs over all phases and return the trained model.

        Per phase, the mean loss is appended to ``loss_history`` (Train) or
        ``val_loss_history`` (Valid) and printed together with the wall time.
        """
        print(f"Phases: {self.phases}")
        for epoch in range(epochs):
            print(f'Epoch: {epoch+1}/{epochs}')
            print("-" * 10)

            for phase in self.phases:
                start = time.time()
                epoch_loss = self._run_phase(phase)

                if self.metric is not None:
                    epoch_metric = self.metric.compute()
                    if phase == "Train":
                        self.metric_history.append(epoch_metric)
                    else:
                        self.val_metric_history.append(epoch_metric)

                if phase == "Train":
                    self.loss_history.append(epoch_loss)
                else:
                    self.val_loss_history.append(epoch_loss)

                end = time.time()
                if self.metric is not None:
                    print(f"Phase: {phase}, Loss: {epoch_loss}, Metric: {epoch_metric}, Time: {round(end-start, 3)}")
                else:
                    print(f"Phase: {phase}, Loss: {epoch_loss}, Time: {round(end-start, 3)}")

            # Decay the learning rate once per epoch, after this epoch's optimizer steps.
            if self.scheduler is not None:
                self.scheduler.step()

        return self.model

In [None]:
# Pass the scheduler created alongside the optimizer (otherwise it never reaches
# the trainer) and the notebook's `device`, so the trainer runs where the model
# lives instead of relying on its hard-coded CUDA default.
trainer = TrainModel(
    model=model,
    train_dl=train_dl,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    device=device,
    valid_dl=test_dl,
)

In [None]:
# Number of full passes over the training loader.
NUM_EPOCHS = 500
# Rebind `model` to whatever train_step hands back.
model = trainer.train_step(NUM_EPOCHS)

In [None]:
# Visual sanity check: plot test triplets titled with the embedding distances
# (anchor->positive, anchor->negative).
N_VIZ_TRIPLETS = 64
model.eval()
# Dedicated batch-of-one loader, so `test_dl` (used by the trainer for
# validation) is not silently replaced if earlier cells are re-run.
viz_dl = DataLoader(test_ds, batch_size=1, shuffle=True, num_workers=0)

with torch.no_grad():
    # enumerate + break stops cleanly when the test set has fewer than
    # N_VIZ_TRIPLETS items (repeated next() would raise StopIteration).
    for batch_idx, (anchor, positive, negative) in enumerate(viz_dl):
        if batch_idx >= N_VIZ_TRIPLETS:
            break
        # Build the image grid from the CPU tensors before moving them to `device`.
        concat = torch.cat((anchor, positive, negative), 0)
        anc_out = model(anchor.to(device))
        pos_out = model(positive.to(device))
        neg_out = model(negative.to(device))
        pos_distance = F.pairwise_distance(anc_out, pos_out)
        neg_distance = F.pairwise_distance(anc_out, neg_out)
        plot_image(torchvision.utils.make_grid(concat), f'Dissimilarities: ({pos_distance.item():.2f}, {neg_distance.item():.2f})')