In [8]:
import os
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import clip
import torch
import re
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
nltk.download('punkt')
nltk.download('stopwords')

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\ninja\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\ninja\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


True

Loading Training Dataset

In [2]:
data_directory = "../QL4POMR/Datasets/ROCO2/"
df_train = pd.read_csv(data_directory+"train_captions.csv")
df_train['Image'] = df_train['ID'] + '.jpg'
df_train = df_train.drop(columns=['ID'])
df_train.head()

Unnamed: 0,Caption,Image
0,Head CT demonstrating left parotiditis.,ROCOv2_2023_train_000001.jpg
1,Acquired renal cysts in end-stage renal failur...,ROCOv2_2023_train_000002.jpg
2,Computed tomography of the chest showing the r...,ROCOv2_2023_train_000003.jpg
3,Lateral view of the sacrum showing the low con...,ROCOv2_2023_train_000004.jpg
4,Thoracic CT scan showing perihilar pulmonary l...,ROCOv2_2023_train_000005.jpg


Pre-Processing

In [9]:
# Define the preprocessing function
def preprocess_caption(caption):
    caption = caption.lower()
    caption = re.sub(r'[^\w\s]', '', caption)
    tokens = word_tokenize(caption)
    stop_words = set(stopwords.words('english'))
    tokens = [word for word in tokens if word not in stop_words]
    preprocessed_caption = ' '.join(tokens)
    return preprocessed_caption

In [15]:
# Apply preprocessing to the captions
df_train['Caption'] = df_train['Caption'].apply(preprocess_caption)
df_train.head()
max_length = df_train['Caption'].str.len().max()
print("The maximum length of the string column is:", max_length)
avg_length = df_train['Caption'].str.len().mean()
print("The average length of the string column is:", avg_length)

The maximum length of the string column is: 2278
The average length of the string column is: 115.84043348902149


In [17]:
class ROCO2Dataset(Dataset):
    def __init__(self, dataframe, image_dir, transform=None):
        self.dataframe = dataframe
        self.image_dir = image_dir
        self.transform = transform

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        img_name = os.path.join(self.image_dir, self.dataframe.iloc[idx, 1])
        image = Image.open(img_name).convert("RGB")
        caption = self.dataframe.iloc[idx, 0]

        if self.transform:
            image = self.transform(image)
            
        # Truncate the caption if it is too long for CLIP(Context length max 77)
        if len(caption) > 77:
            caption = caption[:77]
        
        return image, caption

In [20]:
# Prepare dataset and dataloader
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# Modify dataset to use CLIP's preprocessing
dataset = ROCO2Dataset(df_train, data_directory+"train_images/train/", transform=preprocess)
#Batch size was originally 32
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

# Define the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [23]:
# Training loop
# Originally 10 epochs
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for images, texts in dataloader:
        images = images.to(device)
        texts = clip.tokenize(texts).to(device)
        
        optimizer.zero_grad()
        logits_per_image, logits_per_text = model(images, texts)
        ground_truth = torch.arange(len(images), device=device)
        loss = (torch.nn.CrossEntropyLoss()(logits_per_image, ground_truth) +
                torch.nn.CrossEntropyLoss()(logits_per_text, ground_truth)) / 2
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(dataloader)}")

# Save the model
torch.save(model.state_dict(), 'clip_roco2.pth')

Epoch 1/10, Loss: nan
Epoch 2/10, Loss: nan
Epoch 3/10, Loss: nan
Epoch 4/10, Loss: nan
Epoch 5/10, Loss: nan
Epoch 6/10, Loss: nan
Epoch 7/10, Loss: nan
Epoch 8/10, Loss: nan
Epoch 9/10, Loss: nan
Epoch 10/10, Loss: nan
