In [None]:
!pip install ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git

In [None]:
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
import clip
from PIL import Image
from sklearn.model_selection import train_test_split
import numpy as np

In [None]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
# (Author's note: when running on GPU, use mixed precision training.)
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# jit=False must be set for the loaded model to be trainable.
model, preprocess = clip.load("RN50", device=device, jit=False)

In [None]:
from torchvision import transforms

class ImageTextDataset(Dataset):
    """Dataset of (image, tokenised text) pairs for CLIP fine-tuning.

    Args:
        list_image_path: Image file names, each appended to ``base_path``.
        list_txt: Texts paired index-by-index with ``list_image_path``; they are
            tokenised once, up front, with ``clip.tokenize``.
        transform: Optional callable applied to the RGB ``PIL.Image`` before it
            is returned.
        base_path: Directory that contains the image files. Defaults to
            ``DEFAULT_BASE_PATH``. A trailing ``/`` is added when missing.

    Raises:
        ValueError: If the number of images and the number of tokenised texts
            differ.
    """

    # Location used when no ``base_path`` is given (the original hard-coded value).
    DEFAULT_BASE_PATH = "/home/hous/Desktop/LLAVA/CarDD_release/CarDD_COCO/data/"

    def __init__(self, list_image_path, list_txt, transform=None, base_path=None):
        if base_path is None:
            base_path = self.DEFAULT_BASE_PATH
        # Paths are built by plain concatenation, so make sure a separator is present.
        if base_path and not base_path.endswith("/"):
            base_path += "/"

        # Prepend the path to each image file name
        self.image_path = [base_path + file_name for file_name in list_image_path]
        self.title = clip.tokenize(list_txt)
        self.transform = transform

        if len(self.image_path) != len(self.title):
            raise ValueError(
                f"Got {len(self.image_path)} image paths but {len(self.title)} texts; "
                "they must be paired one-to-one."
            )

    def __getitem__(self, idx):
        """Return ``(image, title)`` for sample ``idx``."""
        # The context manager closes the underlying file as soon as the pixel
        # data has been copied into the converted RGB image.
        with Image.open(self.image_path[idx]) as raw_image:
            image = raw_image.convert("RGB")  # Convert to RGB
        if self.transform:
            image = self.transform(image)
        # NOTE(review): this relies on the module-level `device`. Moving tensors
        # to a CUDA device here is likely to fail with DataLoader workers
        # (num_workers > 0) -- consider moving whole batches in the training loop.
        title = self.title[idx].to(device)
        return image, title

    def __len__(self):
        """Return the number of samples (one per tokenised text)."""
        return len(self.title)

In [None]:
# Load the CSV table; later cells read its 'file_name' and 'damages' columns.
csv_file = 'pair.csv'  # Replace with your CSV file path
data = pd.read_csv(filepath_or_buffer=csv_file)

In [None]:
from torchvision.transforms import RandomCrop, GaussianBlur, RandomGrayscale

In [None]:
BATCH_SIZE = 2
# Fixed seed: the train/validation split must be identical across kernel
# restarts, otherwise resuming from a checkpoint mixes training rows into the
# validation set and makes the stored best validation loss meaningless.
SEED = 42
resize_size = 224

# Per-channel normalisation statistics shared by the training and validation pipelines.
NORMALIZE_MEAN = [0.48145466, 0.4578275, 0.40821073]
NORMALIZE_STD = [0.26862954, 0.26130258, 0.27577711]

# Split data into training and validation
train_data, val_data = train_test_split(data, test_size=0.2, random_state=SEED)  # 20% for validation

# Training pipeline: resize slightly larger than the target, then augment.
transform = transforms.Compose([
    transforms.Resize((resize_size + 20, resize_size + 20)),  # slightly bigger for cropping
    transforms.RandomCrop(resize_size),  # random cropping
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
    transforms.RandomApply([GaussianBlur(kernel_size=3)], p=0.2),  # apply Gaussian Blur with a probability of 0.2
    transforms.RandomGrayscale(p=0.2),  # convert to grayscale with a probability of 0.2
    transforms.ToTensor(),
    transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD),
])

# Validation pipeline: deterministic resize + normalisation only (no augmentation).
val_transform = transforms.Compose([
    transforms.Resize((resize_size, resize_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD),
])

train_dataset = ImageTextDataset(
    train_data['file_name'].tolist(),
    train_data['damages'].tolist(),
    transform=transform,
)
val_dataset = ImageTextDataset(
    val_data['file_name'].tolist(),
    val_data['damages'].tolist(),
    transform=val_transform,
)

# Create DataLoaders for both datasets.
# The training loader reshuffles every epoch so that batch composition differs
# between epochs; validation order is kept fixed.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

In [None]:
# Function to evaluate the model on validation data
def evaluate_model(model, dataloader, criterion=None):
    """Return the mean per-batch loss of ``model`` over ``dataloader``.

    For every batch the loss is the average of the image-to-text and
    text-to-image cross-entropy terms, i.e. the same quantity the training loop
    optimises, so training and validation losses are directly comparable.

    Args:
        model: Callable module returning ``(logits_per_image, logits_per_text)``
            for ``model(images, texts)``. Left in eval mode on return.
        dataloader: Iterable yielding ``(images, texts)`` batches.
        criterion: Loss applied to each logits tensor against the diagonal
            targets ``0..batch-1``. Defaults to ``nn.CrossEntropyLoss()``.

    Returns:
        float: Sum of the per-batch losses divided by the number of batches.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    if criterion is None:
        criterion = nn.CrossEntropyLoss()
    # Evaluate on whichever device the model's parameters live on, so the
    # function does not depend on a global defined in another cell.
    model_device = next(model.parameters()).device

    model.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for images, texts in dataloader:
            images = images.to(model_device)
            texts = texts.to(model_device)
            logits_per_image, logits_per_text = model(images, texts)
            # The i-th image is paired with the i-th text, so targets are 0..batch-1.
            ground_truth = torch.arange(len(images), dtype=torch.long, device=model_device)
            batch_loss = (
                criterion(logits_per_image, ground_truth)
                + criterion(logits_per_text, ground_truth)
            ) / 2
            total_loss += batch_loss.item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError("evaluate_model received an empty dataloader")
    return total_loss / num_batches

In [None]:
def convert_models_to_fp32(model):
    """Cast every parameter of ``model`` (and its gradient, if any) to float32 in place.

    Parameters that have no gradient yet (``p.grad is None``) only have their
    data converted instead of raising ``AttributeError``.

    Args:
        model: Module whose ``parameters()`` are converted.
    """
    for p in model.parameters():
        p.data = p.data.float()
        if p.grad is not None:
            p.grad.data = p.grad.data.float()


# Choose the weight precision: CLIP's weight conversion on any non-CPU device,
# plain float32 on the CPU.
if device != "cpu":
    clip.model.convert_weights(model)
else:
    model.float()

# One cross-entropy criterion per direction (image->text and text->image).
loss_img, loss_txt = nn.CrossEntropyLoss(), nn.CrossEntropyLoss()

# Adam over every model parameter.
optimizer = optim.Adam(
    model.parameters(),
    lr=1e-5,
    betas=(0.9, 0.98),
    eps=1e-6,
    weight_decay=1e-4,
)

In [None]:
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Multiply the learning rate by 0.3 once the monitored value has stopped
# decreasing for more than 5 consecutive scheduler steps.
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.3,
    patience=5,
    verbose=True,
)

In [None]:
import torch
from tqdm import tqdm
import os 

# Training loop
EPOCHS = 100  # Set the number of epochs
best_loss = float('inf')
start_epoch = 0  # Default start epoch

# Check if a saved model exists
saved_model_path = 'best_matching_model.pth'
if os.path.isfile(saved_model_path):
    # map_location lets a checkpoint written on one device be restored on another.
    checkpoint = torch.load(saved_model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state'])
    # Optimizer steps are taken while the weights are float32 (see below), and
    # loading the optimizer state casts it to the parameters' current dtype, so
    # restore it with float32 weights and switch back to the training precision.
    if device != "cpu":
        model.float()
    optimizer.load_state_dict(checkpoint['optimizer_state'])
    if device != "cpu":
        clip.model.convert_weights(model)
    # The checkpoint stores the index of the last *completed* epoch, so resume
    # with the following one instead of repeating it.
    start_epoch = checkpoint['epoch'] + 1
    best_loss = checkpoint['best_loss']
    print(f"Resuming training from epoch {start_epoch} with best validation loss {best_loss}")

for epoch in range(start_epoch, EPOCHS):
    model.train()
    total_train_loss = 0

    # Adding tqdm for training progress (description is constant within an epoch)
    train_loop = tqdm(
        train_dataloader,
        total=len(train_dataloader),
        leave=True,
        desc=f'Epoch {epoch+1}/{EPOCHS} [Training]',
    )
    for images, texts in train_loop:
        # Both inputs must live on the same device as the model.
        images = images.to(device)
        texts = texts.to(device)

        optimizer.zero_grad()
        logits_per_image, logits_per_text = model(images, texts)
        # The i-th image is paired with the i-th text, so targets are 0..batch-1.
        ground_truth = torch.arange(images.size(0), device=device)
        total_loss = (loss_img(logits_per_image, ground_truth) + loss_txt(logits_per_text, ground_truth)) / 2
        total_loss.backward()

        if device == "cpu":
            optimizer.step()
        else:
            # Take the optimizer step in float32, then return the weights to the
            # precision used for the forward/backward pass.
            convert_models_to_fp32(model)
            optimizer.step()
            clip.model.convert_weights(model)

        batch_loss = total_loss.item()
        total_train_loss += batch_loss
        train_loop.set_postfix(train_loss=batch_loss)

    avg_train_loss = total_train_loss / len(train_dataloader)

    # Progress for validation
    val_loss = evaluate_model(model, val_dataloader)
    tqdm.write(f'Epoch {epoch+1}/{EPOCHS}, Training Loss: {avg_train_loss:.4f}, Validation Loss: {val_loss:.4f}')

    # Let the plateau scheduler see the validation loss; without this call the
    # learning rate is never reduced.
    scheduler.step(val_loss)

    # Save the best model along with epoch and optimizer state
    if val_loss < best_loss:
        best_loss = val_loss
        torch.save({
            'epoch': epoch,
            'model_state': model.state_dict(),
            'optimizer_state': optimizer.state_dict(),
            'best_loss': best_loss
        }, saved_model_path)

print("Training completed.")