# import 

In [None]:
import torch 
from torch import nn 
import os 
import matplotlib.pyplot as plt 
from torch.utils.data import Dataset , DataLoader
from torchvision.transforms import Resize
from torchvision.io import read_image
from torchvision import transforms

# custom Dataset

In [None]:
class CustomDataset(Dataset):
    """Dataset yielding ``(anchor_image, image, label)`` triples.

    Images are read from ``<cwd>/data/<anchors_folder>`` and
    ``<cwd>/data/<images_folder>``. The i-th anchor file is paired with the
    i-th image file, and every pair carries the same ``label``.

    Args:
        anchors_folder: Sub-folder of ``data`` holding the anchor images.
        images_folder: Sub-folder of ``data`` holding the images paired with
            the anchors.
        label: Label returned with every pair.
        transform: Optional callable applied to both tensors after scaling.
        max_samples: Upper bound on the number of pairs (default 1500).
            ``None`` disables the cap.
    """

    def __init__(self, anchors_folder, images_folder, label, transform=None,
                 max_samples=1500):
        self.anchors_folder = anchors_folder
        self.images_folder = images_folder
        self.transform = transform
        self.label = label

        # Resolve the directories once so listing and loading agree even if
        # the working directory changes later.
        data_root = os.path.join(os.getcwd(), 'data')
        self._anchors_dir = os.path.join(data_root, anchors_folder)
        self._images_dir = os.path.join(data_root, images_folder)

        # os.listdir returns entries in arbitrary order; sort so that the
        # anchor/image pairing is reproducible across runs and machines.
        anchor_files = sorted(os.listdir(self._anchors_dir))
        image_files = sorted(os.listdir(self._images_dir))

        # Truncate both lists to a common length so that every index below
        # __len__ is valid for both folders (avoids IndexError when one
        # folder holds fewer files than the other).
        pair_count = min(len(anchor_files), len(image_files))
        if max_samples is not None:
            pair_count = min(pair_count, max_samples)
        self.anchors_paths = anchor_files[:pair_count]
        self.images_paths = image_files[:pair_count]

    def __len__(self):
        """Return the number of (anchor, image) pairs."""
        return len(self.images_paths)

    def __getitem__(self, idx):
        """Load pair ``idx``, scale both images by 1/255, apply ``transform``."""
        anchor_image = read_image(
            os.path.join(self._anchors_dir, self.anchors_paths[idx]))
        image = read_image(
            os.path.join(self._images_dir, self.images_paths[idx]))

        # Dividing by 255 also converts the tensors to floating point.
        anchor_image = anchor_image / 255
        image = image / 255

        if self.transform:
            anchor_image = self.transform(anchor_image)
            image = self.transform(image)
        return anchor_image, image, self.label
        
        

In [None]:
# Anchor/positive pairs are labelled 1; every image is resized to 100x100.
positive_data = CustomDataset(
    anchors_folder='anchors',
    images_folder='positive',
    label=1,
    transform=transforms.Compose([Resize((100, 100))]),
)
# Anchor/negative pairs are labelled 0, with the same resizing.
negative_data = CustomDataset(
    anchors_folder='anchors',
    images_folder='negative',
    label=0,
    transform=transforms.Compose([Resize((100, 100))]),
)


In [None]:
# Scratch arithmetic: 80% of 3000 is 2400.0 -- presumably the intended size
# of the training split (TODO confirm); the result is not stored anywhere.
3000 * .8

In [None]:
# Concatenate the positive-pair and negative-pair datasets into one dataset.
dataset = torch.utils.data.ConcatDataset([positive_data  ,negative_data])

In [None]:
# Shuffled loader over the full combined dataset, 16 samples per batch.
data_loader = DataLoader(dataset=dataset, batch_size=16, shuffle=True)




In [None]:
# random_split must receive the Dataset itself: a DataLoader cannot be indexed,
# so subsets built on top of it fail as soon as they are read.
# Use an 80/20 train/validation split whose sizes follow the dataset length.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_set, val_set = torch.utils.data.random_split(dataset, [train_size, val_size])

# Loaders consumed by the training loop; only the training data is shuffled.
train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
val_loader = DataLoader(val_set, batch_size=16, shuffle=False)

# model

In [None]:
class L1Dist(nn.Module):
    """Parameter-free layer returning the element-wise absolute difference."""

    def __init__(self, **kwargs):
        # Keyword arguments are accepted but not used.
        super().__init__()

    def forward(self, input_embedding, validation_embedding):
        """Return ``|input_embedding - validation_embedding|`` element-wise."""
        difference = input_embedding - validation_embedding
        return difference.abs()

In [None]:
class embedding(nn.Module):
    """Convolutional encoder mapping a 3-channel image to a 4096-d vector.

    Three conv -> ReLU -> 2x2 max-pool stages are followed by a conv -> ReLU
    -> flatten stage, a linear layer and a sigmoid. The linear layer expects
    6400 input features, which is what 100x100 inputs produce
    (256 channels x 5 x 5).
    """

    def __init__(self):
        super().__init__()
        self.b1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(10, 10)),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),
        )
        self.b2 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(7, 7)),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),
        )
        self.b3 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(4, 4)),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),
        )
        # The last stage flattens instead of pooling.
        self.b4 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(4, 4)),
            nn.ReLU(),
            nn.Flatten(),
        )
        self.out = nn.Linear(in_features=6400, out_features=4096)
        self.s = nn.Sigmoid()

    def forward(self, x):
        """Run ``x`` through the four stages, then linear + sigmoid."""
        for stage in (self.b1, self.b2, self.b3, self.b4):
            x = stage(x)
        return self.s(self.out(x))

        
        

In [None]:
class Model(nn.Module):
    """Siamese network: shared encoder, L1 distance, linear + sigmoid head."""

    def __init__(self):
        super().__init__()
        self.embedding = embedding()
        self.distance = L1Dist()
        self.out = nn.Linear(4096, 1)
        self.s = nn.Sigmoid()

    def forward(self, x_1, x_2):
        """Return the sigmoid score for the pair ``(x_1, x_2)``."""
        # Both inputs go through the same encoder (shared weights).
        first_embedding = self.embedding(x_1)
        second_embedding = self.embedding(x_2)
        gap = self.distance(first_embedding, second_embedding)
        return self.s(self.out(gap))
        
        
        

In [None]:
# Instantiate the Siamese model with freshly initialised weights.
model = Model()

In [None]:
# Adam over all model parameters with a learning rate of 1e-4.
optim = torch.optim.Adam(params=model.parameters(), lr=1e-4)
# Binary cross-entropy on the model's sigmoid output.
loss_fn = nn.BCELoss()

# train step 

In [None]:
def train_step(model: torch.nn.Module, 
               dataloader: torch.utils.data.DataLoader, 
               loss_fn: torch.nn.Module, 
               optimizer: torch.optim.Optimizer):
    """Train ``model`` for one pass over ``dataloader``.

    Each batch must unpack to ``(anchors, images, labels)``. The model is
    called as ``model(anchors, images)``; its output is squeezed on the last
    dimension and rounded to obtain the predicted class.

    Returns:
        ``(mean_loss, mean_accuracy)``, both averaged over batches.
    """
    model.train()

    batch_losses = []
    batch_accuracies = []

    for anchors, images, labels in dataloader:
        # Forward pass; drop the trailing dimension so shapes match labels.
        scores = model(anchors, images).type(torch.float).squeeze(-1)

        batch_loss = loss_fn(scores, labels.type(torch.float))
        batch_losses.append(batch_loss.item())

        # Standard update: clear old gradients, backpropagate, step.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        # Fraction of this batch whose rounded score equals the label.
        predicted = torch.round(scores)
        correct = (predicted == labels).sum().item()
        batch_accuracies.append(correct / len(scores))

    # Per-batch averages (every batch weighs the same, whatever its size).
    batch_count = len(dataloader)
    train_loss = sum(batch_losses) / batch_count
    train_acc = sum(batch_accuracies) / batch_count
    return train_loss, train_acc

In [None]:
def test_step(model: torch.nn.Module, 
              dataloader: torch.utils.data.DataLoader, 
              loss_fn: torch.nn.Module):
    # Put model in eval mode
    model.eval() 
    
    # Setup test loss and test accuracy values
    test_loss, test_acc = 0, 0
    
    # Turn on inference context manager
    with torch.inference_mode():
        # Loop through DataLoader batches
        for batch, (anchors , images, labels) in enumerate(dataloader):
            # 1. Forward pass
            test_pred_logits = model(anchors  , images)
            test_pred_logits = test_pred_logits.type(torch.float).squeeze(-1)
            # 2. Calculate  and accumulate loss
            loss = loss_fn(test_pred_logits, labels.type(torch.float))
            test_loss += loss.item()
            
            # Calculate and accumulate accuracy
            test_pred_labels = test_pred_logits.round()
            test_acc += ((test_pred_labels == labels).sum().item()/len(test_pred_labels))
            
    # Adjust metrics to get average loss and accuracy per batch 
    test_loss = test_loss / len(dataloader)
    test_acc = test_acc / len(dataloader)
    return test_loss, test_acc

In [None]:
from tqdm.auto import tqdm

# 1. Take in various parameters required for training and test steps
def train(model: torch.nn.Module, 
          train_dataloader: torch.utils.data.DataLoader, 
          test_dataloader: torch.utils.data.DataLoader, 
          optimizer: torch.optim.Optimizer,
          loss_fn: torch.nn.Module = nn.BCELoss(),
          epochs: int = 5):
    """Run ``train_step`` and ``test_step`` for ``epochs`` epochs.

    Args:
        model: Network passed to both step functions.
        train_dataloader: Batches used by ``train_step``.
        test_dataloader: Batches used by ``test_step``.
        optimizer: Optimizer used by ``train_step``.
        loss_fn: Loss applied to the squeezed model output and float labels.
            Defaults to ``nn.BCELoss()``: the step functions pass 1-D float
            predictions and float labels, which is a binary cross-entropy
            setup and does not fit ``nn.CrossEntropyLoss``.
        epochs: Number of passes over the data.

    Returns:
        Dict with keys ``train_loss``, ``train_acc``, ``test_loss`` and
        ``test_acc``, each a list holding one value per epoch.
    """
    # 2. Create empty results dictionary
    results = {"train_loss": [],
        "train_acc": [],
        "test_loss": [],
        "test_acc": []
    }

    # 3. Loop through training and testing steps for a number of epochs
    for epoch in tqdm(range(epochs)):
        train_loss, train_acc = train_step(model=model,
                                           dataloader=train_dataloader,
                                           loss_fn=loss_fn,
                                           optimizer=optimizer)
        test_loss, test_acc = test_step(model=model,
                                        dataloader=test_dataloader,
                                        loss_fn=loss_fn)

        # 4. Print out what's happening
        print(
            f"Epoch: {epoch+1} | "
            f"train_loss: {train_loss:.4f} | "
            f"train_acc: {train_acc:.4f} | "
            f"test_loss: {test_loss:.4f} | "
            f"test_acc: {test_acc:.4f}"
        )

        # 5. Update results dictionary
        results["train_loss"].append(train_loss)
        results["train_acc"].append(train_acc)
        results["test_loss"].append(test_loss)
        results["test_acc"].append(test_acc)

    # 6. Return the filled results at the end of the epochs
    return results

In [None]:
# Train for 5 epochs. train_loader / val_loader are expected to be the
# DataLoaders built from the training and validation splits.
model_0_results = train(
    model=model,
    train_dataloader=train_loader,
    test_dataloader=val_loader,
    optimizer=optim,
    loss_fn=loss_fn,
    epochs=5,
)

In [None]:
# Serialise the whole model object (not just its state_dict) to
# <cwd>/pytorch_model.pth.
# NOTE(review): loading a fully pickled module presumably requires the same
# class definitions to be importable at load time -- confirm before sharing.
torch.save(model,os.path.join(os.getcwd() , 'pytorch_model.pth'))