In [1]:
import h5py
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
import torch
from torchvision import transforms
import random
import numpy as np
import matplotlib.pyplot as plt
from collections import Counter
from torchview import draw_graph
from torch import nn
import torch.optim as optim
import tqdm
import torch.utils.data as utils

In [2]:
class CustomDataset(utils.Dataset):
    """Map-style dataset that eagerly loads every sample of an HDF5 file into memory.

    Every top-level key ``<key>`` of the file is read as a pair of datasets,
    ``<key>/img`` and ``<key>/label``. Items are returned as ``(image, label)``
    tuples, ordered as the file's keys are iterated.

    Args:
        file: path (or anything ``h5py.File`` accepts) of the HDF5 file to read.
    """

    def __init__(self, file):
        self.images = []
        self.labels = []
        with h5py.File(file, 'r') as h5_file:
            for group_name in h5_file.keys():
                # `[:]` reads the image data out of the file; `[()]` reads the whole
                # label dataset (presumably a scalar class index -- TODO confirm).
                self.images.append(h5_file[group_name + '/img'][:])
                self.labels.append(h5_file[group_name + '/label'][()])

    def __len__(self):
        """Number of samples loaded from the file."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position ``idx``."""
        image = self.images[idx]
        label = self.labels[idx]
        return image, label
        
def init_model(module):
    """Initialise a single module in place; intended for use with ``nn.Module.apply``.

    ``nn.Linear`` layers get Xavier-normal weights and, when present, a zero bias.
    Every other module type is returned from immediately and left untouched.
    """
    if not isinstance(module, nn.Linear):
        return
    nn.init.xavier_normal_(module.weight)
    if module.bias is not None:
        nn.init.zeros_(module.bias)
        
class Model(nn.Module):
    """Frozen PLIP (CLIP) vision encoder followed by a trainable MLP classification head.

    Images are preprocessed with the checkpoint's ``CLIPProcessor`` inside ``forward``.
    The first token (CLS) of the vision transformer's last hidden state is fed to the head.

    Args:
        model_name: Hugging Face hub id (or local path) of the CLIP checkpoint.
        num_classes: number of logits produced by the classifier head.
    """

    def __init__(self, model_name="vinid/plip", num_classes=2):
        super().__init__()
        # Load the checkpoint once and take the vision tower from it; previously it was
        # loaded a second time just to grab `.vision_model`. `self.PLIP` is kept as an
        # attribute for backward compatibility.
        self.PLIP = CLIPModel.from_pretrained(model_name)
        self.processor = CLIPProcessor.from_pretrained(model_name)
        self.vision_model = self.PLIP.vision_model
        # Freeze the whole pretrained model (text tower included): only the head trains.
        for parameter in self.PLIP.parameters():
            parameter.requires_grad = False

        # Read the embedding width from the config instead of hard-coding it (was 768).
        hidden_size = self.PLIP.config.vision_config.hidden_size
        self.classifier = nn.Sequential(nn.Linear(hidden_size, 384), nn.ReLU(),
                                        nn.Linear(384, 192), nn.ReLU(),
                                        nn.Linear(192, 48), nn.ReLU(),
                                        nn.Linear(48, num_classes))
        # Initialise ONLY the new head. `self.apply(init_model)` would recurse into the
        # pretrained encoder and overwrite all of its nn.Linear weights.
        self.classifier.apply(init_model)

    def forward(self, image):
        """Return classification logits for a batch of images.

        Args:
            image: a batch of images accepted by ``CLIPProcessor`` (e.g. what the
                DataLoader yields). With ``do_rescale=False``, pixel values are assumed
                to already be scaled to [0, 1] -- TODO confirm against the HDF5 data.

        Returns:
            Tensor of shape (batch, num_classes) with unnormalised logits.
        """
        inputs = self.processor(images=image, return_tensors="pt", padding=True, do_rescale=False)
        # The processor's output is on CPU; move it to wherever the encoder lives,
        # so the model works after `.to('cuda')`.
        encoder_device = next(self.vision_model.parameters()).device
        pixel_values = inputs['pixel_values'].to(encoder_device)
        with torch.no_grad():
            embedding = self.vision_model(pixel_values=pixel_values).last_hidden_state[:, 0, :]
        return self.classifier(embedding)

In [None]:
# Reproducibility: seed every RNG involved in weight init and DataLoader shuffling.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

train_batch_size, val_batch_size, test_batch_size = 100, 100, 100
# Only the training data needs shuffling. Evaluation metrics do not depend on order,
# and a fixed order keeps evaluation runs reproducible.
train_set = utils.DataLoader(CustomDataset('train.h5'), batch_size=train_batch_size, shuffle=True)
validation_set = utils.DataLoader(CustomDataset('val.h5'), batch_size=val_batch_size, shuffle=False)
test_set = utils.DataLoader(CustomDataset('test.h5'), batch_size=test_batch_size, shuffle=False)

model = Model()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
loss_fn = nn.CrossEntropyLoss()
# Hand Adam only trainable parameters; those with requires_grad=False (the frozen
# encoder) are excluded.
optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=3e-4)

In [139]:
def _train_one_epoch(loader):
    """Run one optimisation pass of the global ``model`` over ``loader``.

    Uses the global ``loss_fn`` and ``optimizer``.

    Returns:
        (mean per-batch loss, accuracy in percent over all samples of ``loader``).
    """
    total_loss, correct = 0.0, 0
    model.train()
    for image, label in loader:
        logits = model(image)
        # Labels must live on the same device as the logits. CrossEntropyLoss needs
        # int64 class indices; the dtype read from HDF5 is not guaranteed -- verify.
        target = label.to(logits.device).long()
        loss = loss_fn(logits, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        # argmax over the class dimension, not the flattened tensor.
        correct += (logits.argmax(dim=1) == target).sum().item()
    return total_loss / max(len(loader), 1), 100 * correct / max(len(loader.dataset), 1)


def _evaluate(loader):
    """Evaluate the global ``model`` on ``loader`` without tracking gradients.

    Returns:
        (mean per-batch loss, accuracy in percent over all samples of ``loader``).
    """
    total_loss, correct = 0.0, 0
    model.eval()
    with torch.inference_mode():
        for image, label in loader:
            logits = model(image)
            target = label.to(logits.device).long()
            total_loss += loss_fn(logits, target).item()
            correct += (logits.argmax(dim=1) == target).sum().item()
    return total_loss / max(len(loader), 1), 100 * correct / max(len(loader.dataset), 1)


def use(epochs):
    """Train ``model`` for up to ``epochs`` epochs, report test accuracy and plot curves.

    Each epoch trains on ``train_set`` and evaluates on ``validation_set``. Training
    stops early once both train and validation accuracy reach 100%. The test set is
    evaluated only once, after training, so it never influences training decisions.

    Args:
        epochs: maximum number of training epochs (0 skips training).

    Returns:
        (train_accs, val_accs, test_acc): per-epoch train/validation accuracies in
        percent, and the final test accuracy in percent.
    """
    train_accs, val_accs = [], []
    progress = tqdm.tqdm(range(epochs), desc="epoch")  # `tqdm` is imported as a module
    for epoch in progress:
        train_loss, train_acc = _train_one_epoch(train_set)
        val_loss, val_acc = _evaluate(validation_set)
        train_accs.append(train_acc)
        val_accs.append(val_acc)
        progress.set_postfix(train_loss=train_loss, val_loss=val_loss)

        if train_acc == 100 and val_acc == 100:
            print(f"Optimum reached at epoch {epoch + 1}")
            break

    _, test_acc = _evaluate(test_set)
    print(f"Test accuracy: {test_acc:.2f}%")

    # Number of epochs actually run; also correct when training stopped early or epochs == 0.
    epochs_run = range(1, len(train_accs) + 1)
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.plot(epochs_run, train_accs, color='blue', label='Train accuracy')
    ax.plot(epochs_run, val_accs, color='green', label='Validation accuracy')
    ax.set_ylim([0, 110])
    ax.legend(loc='lower right')
    ax.set_ylabel('Accuracy')
    ax.set_xlabel('Epochs')
    ax.set_title('Accuracy over epochs')
    plt.show()
    return train_accs, val_accs, test_acc