In [1]:
%%capture
# Install the notebook's dependencies. The %%capture cell magic above swallows
# pip's console output so the cell renders without an install log.
!pip install pillow
!pip install pandas
!pip install tqdm
!pip install transformers
!pip install huggingface_hub scikit-learn
# torch / torchvision / torchaudio are pinned to matching releases and fetched
# from the PyTorch wheel index for CUDA 11.8 (cu118).
!pip install torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 --index-url https://download.pytorch.org/whl/cu118

In [2]:
import copy
import getpass
import io
import math
import os

import huggingface_hub
import pandas as pd
import torch
import torch.nn as nn
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torchvision.io import ImageReadMode, read_image
from tqdm import tqdm
from transformers import CLIPProcessor, CLIPModel, get_cosine_schedule_with_warmup

In [3]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Device: {device}")

# Load the pretrained CLIP weights and the matching processor
# (image preprocessing + tokenizer) from the same checkpoint.
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
model = model.to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Rewrite any non-contiguous parameter storage as a contiguous tensor.
for param in model.parameters():
    if param.is_contiguous():
        continue
    param.data = param.data.contiguous()

# Longest token sequence the tokenizer reports for this checkpoint;
# used as the truncation limit when batching text.
max_length = processor.tokenizer.model_max_length

Device: cuda




In [4]:
def collate_fn(batch):
    """Turn a list of ``(text, image)`` samples into one processor batch.

    The texts and images are gathered into two parallel lists and handed to
    the module-level ``processor``, which pads the text to the longest entry,
    truncates it at ``max_length`` and returns PyTorch tensors. RGB
    conversion is switched off in the processor call.
    """
    captions = []
    pictures = []
    for sample in batch:
        captions.append(sample[0])
        pictures.append(sample[1])

    return processor(
        text=captions,
        images=pictures,
        return_tensors="pt",
        padding='longest',
        truncation=True,
        max_length=max_length,
        do_convert_rgb=False,
    )

In [5]:
# Root folder of the training images. The dataset class resolves each image as
# <dataset_path>/<label>/<file name>.
dataset_path = 'Dataset/train'
# CSV with the image annotations; it is loaded with pandas when the trainer is
# constructed. Column 0 is read as the file name and column 1 as the label, and
# a column named 'label' is used to stratify the train/validation split.
image_annotations_path = 'Dataset/crop-disease-data.csv'

In [6]:
class image_title_dataset(Dataset):
    """Map-style dataset yielding ``(label, image)`` pairs.

    Each image is read from ``<image_dir>/<label>/<file name>``, where the
    file name is taken from column 0 and the label from column 1 of the
    annotation table.
    """

    def __init__(self, image_dir, annotation_path):
        """Store the image root and the annotation table.

        Args:
            image_dir: Root directory that contains one sub-folder per label.
            annotation_path: The annotation table itself (indexed positionally
                with ``.iloc``), not a path to it, despite the parameter name.
                Column 0 holds the image file name, column 1 the label.
        """
        self.image_dir = image_dir
        self.image_data = annotation_path

    def __len__(self):
        """Return the number of annotated samples."""
        return len(self.image_data)

    def __getitem__(self, idx):
        """Return ``(label, image)`` for the row at position ``idx``.

        The image is always decoded as a 3-channel RGB tensor.
        """
        label = self.image_data.iloc[idx, 1]
        file_name = self.image_data.iloc[idx, 0]
        image_path = os.path.join(self.image_dir, label, file_name)
        # Decode explicitly as RGB: the batching code calls the processor with
        # do_convert_rgb=False, so a grayscale or RGBA file would otherwise
        # reach the model with 1 or 4 channels instead of 3.
        image = read_image(image_path, mode=ImageReadMode.RGB)
        return label, image

In [7]:
class CLIPPreprocessor:
    """Fine-tunes a CLIP model on (label text, image) pairs with early stopping.

    The annotation CSV is split into stratified train/validation sets, the
    model is trained with the symmetric image/text cross-entropy loss, the
    learning rate follows a cosine schedule with warm-up that is stepped once
    per batch, and the checkpoint with the lowest validation loss is saved.
    """

    def __init__(self, model, dataset_path, num_epochs, patience, dataset_annotation_path, batch_size=32, val_split=0.1, lr=5e-5, warm_up_ratio=0.2, checkpoint_dir='Model_ckpt'):
        """Store the training configuration and build optimizer and losses.

        Args:
            model: CLIP model to fine-tune; batches are moved to the device of
                its first parameter.
            dataset_path: Root image directory passed to the dataset class.
            num_epochs: Maximum number of training epochs.
            patience: Number of consecutive epochs without a better validation
                loss after which training stops.
            dataset_annotation_path: Path of the annotation CSV (must contain
                a ``label`` column, which is used for the stratified split).
            batch_size: Batch size for both data loaders.
            val_split: Fraction of the annotations held out for validation.
            lr: Peak learning rate of the optimizer.
            warm_up_ratio: Warm-up length expressed as a fraction of the
                number of batches in ONE epoch.
            checkpoint_dir: Directory the best model is saved to and later
                reloaded from when pushing to the hub.
        """
        self.model = model
        self.batch_size = batch_size
        self.val_split = val_split
        self.num_epochs = num_epochs
        self.patience = patience
        self.dataset_path = dataset_path
        self.warm_up_ratio = warm_up_ratio
        self.checkpoint_dir = checkpoint_dir

        # Follow the model's own device instead of relying on a notebook global.
        first_param = next(model.parameters(), None)
        self.device = first_param.device if first_param is not None else torch.device("cpu")

        self.dataset_annotation = pd.read_csv(dataset_annotation_path)

        # NOTE(review): torch.optim.Adam applies weight_decay as an L2 term
        # added to the gradient, whereas torch.optim.AdamW decouples it.
        # Confirm that plain Adam with weight_decay=0.2 is what is intended.
        self.optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.98), eps=1e-6, weight_decay=0.2)
        self.img_criteria = nn.CrossEntropyLoss()
        self.txt_criteria = nn.CrossEntropyLoss()

    def prepare_data(self):
        """Create the stratified train/validation split and its data loaders."""
        # Despite the attribute names these are DataFrames, not paths.
        self.train_path, self.val_path = train_test_split(
            self.dataset_annotation,
            test_size=self.val_split,
            stratify=self.dataset_annotation['label'],
            random_state=42,
        )

        self.train_dataset = image_title_dataset(self.dataset_path, self.train_path)
        self.val_dataset = image_title_dataset(self.dataset_path, self.val_path)

        self.train_loader = DataLoader(self.train_dataset, batch_size=self.batch_size, collate_fn=collate_fn, shuffle=True)
        # The contrastive loss depends on which samples share a batch, so the
        # validation batches are kept fixed to make epochs comparable.
        self.val_loader = DataLoader(self.val_dataset, batch_size=self.batch_size, collate_fn=collate_fn, shuffle=False)

    def prepare_scheduler(self):
        """Build the cosine learning-rate schedule (requires ``prepare_data``)."""
        steps_per_epoch = len(self.train_loader)
        self.scheduler = get_cosine_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=int(steps_per_epoch * self.warm_up_ratio),
            num_training_steps=steps_per_epoch * self.num_epochs,
        )

    def _contrastive_loss(self, batch):
        """Return the mean of the image->text and text->image losses for a batch."""
        batch = batch.to(self.device)
        outputs = self.model(**batch)
        logits_per_image = outputs.logits_per_image
        logits_per_text = outputs.logits_per_text

        # The i-th image is paired with the i-th text, so the target class for
        # row i is simply i.
        ground_truth = torch.arange(len(batch['input_ids']), dtype=torch.long, device=logits_per_image.device)
        image_loss = self.img_criteria(logits_per_image, ground_truth)
        text_loss = self.txt_criteria(logits_per_text, ground_truth)
        return (image_loss + text_loss) / 2

    def train_epoch(self, epoch):
        """Train for one epoch.

        Args:
            epoch: 1-based epoch number, used only for the progress bar.

        Returns:
            The mean per-batch training loss as a Python float.
        """
        self.model.train()
        running_loss = 0.0
        pbar = tqdm(self.train_loader, total=len(self.train_loader))
        for batch in pbar:
            self.optimizer.zero_grad()
            total_loss = self._contrastive_loss(batch)

            total_loss.backward()
            self.optimizer.step()
            # The schedule is defined in optimizer steps, so it advances after
            # every batch rather than once per epoch.
            self.scheduler.step()

            # Accumulate a plain float so no autograd history is kept alive.
            batch_loss = total_loss.item()
            running_loss += batch_loss
            pbar.set_description(f"Epoch {epoch}/{self.num_epochs}, Train_Loss: {batch_loss:.4f}")

        # Each batch loss is already a mean over its samples, so average over
        # the number of batches.
        return running_loss / max(len(self.train_loader), 1)

    def validate(self):
        """Evaluate on the validation set.

        Returns:
            The mean per-batch validation loss as a Python float.
        """
        self.model.eval()
        running_loss = 0.0
        with torch.no_grad():
            pbar = tqdm(self.val_loader, total=len(self.val_loader))
            for batch in pbar:
                batch_loss = self._contrastive_loss(batch).item()
                running_loss += batch_loss
                pbar.set_description(f"Val_Loss: {batch_loss:.4f}")

        return running_loss / max(len(self.val_loader), 1)

    def train(self):
        """Run the full training loop with checkpointing and early stopping."""
        best_val_loss = float('inf')
        patience_counter = 0
        self.prepare_data()
        self.prepare_scheduler()

        for epoch in range(self.num_epochs):
            train_loss = self.train_epoch(epoch + 1)
            val_loss = self.validate()

            current_lr = self.optimizer.param_groups[0]['lr']

            print(f"Epoch {epoch+1}/{self.num_epochs}")
            print(f"Train Loss: {train_loss:.4f}")
            print(f"Validation Loss: {val_loss:.4f}")
            print(f"Current LR: {current_lr:.6f}")

            if val_loss < best_val_loss:
                # New best model: checkpoint it and reset the patience budget.
                best_val_loss = val_loss
                patience_counter = 0
                self.save_model()
            else:
                patience_counter += 1

            if patience_counter >= self.patience:
                print("Early stopping triggered")
                break

    def save_model(self):
        """Save the current model weights to ``checkpoint_dir``."""
        # Make every parameter contiguous before serialising.
        for param in self.model.parameters():
            if not param.is_contiguous():
                param.data = param.data.contiguous()

        self.model.save_pretrained(self.checkpoint_dir)

    def push_to_HFhub(self, repo_id, access_token):
        """Reload the saved best checkpoint and upload it to the Hugging Face Hub.

        Args:
            repo_id: Target repository id on the hub.
            access_token: Access token used to authenticate the upload.
        """
        best_model = CLIPModel.from_pretrained(self.checkpoint_dir)
        best_model.push_to_hub(repo_id=repo_id, token=access_token)

In [None]:
# Training budget: at most 30 epochs, stopping early once 10 epochs in a row
# fail to improve the validation loss.
num_epochs = 30
patience = 10

preprocessor = CLIPPreprocessor(
    model=model,
    dataset_path=dataset_path,
    num_epochs=num_epochs,
    patience=patience,
    dataset_annotation_path=image_annotations_path,
)
preprocessor.train()

In [None]:
# The repository id is not a secret; edit it here.
repo_id = '<Your huggingface repository id>'
# Never paste the access token into the notebook, where it would be saved with
# the file. Take it from the HF_TOKEN environment variable, or prompt for it
# without echoing when the variable is unset.
access_token = os.environ.get('HF_TOKEN') or getpass.getpass('Hugging Face access token: ')
preprocessor.push_to_HFhub(repo_id, access_token)