In [1]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os
import matplotlib.pyplot as plt
import seaborn as sns


import torch
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.ops as ops
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F



# Importing machine learning utilities
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix

from einops import rearrange, repeat

# Importing libraries for medical image handling and dataset setup
import SimpleITK as sitk
import os
import json
import ast
import gc
import shutil
import glob
import sys
import random
from tqdm import tqdm
from pathlib import Path
from PIL import Image

In [2]:
pip install transformers==4.44.2 

Collecting transformers==4.44.2
  Downloading transformers-4.44.2-py3-none-any.whl.metadata (43 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.7/43.7 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.23.2 (from transformers==4.44.2)
  Downloading huggingface_hub-0.36.0-py3-none-any.whl.metadata (14 kB)
Collecting tokenizers<0.20,>=0.19 (from transformers==4.44.2)
  Downloading tokenizers-0.19.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Downloading transformers-4.44.2-py3-none-any.whl (9.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.5/9.5 MB[0m [31m65.4 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hDownloading huggingface_hub-0.36.0-py3-none-any.whl (566 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m566.1/566.1 kB[0m [31m35.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading tokenizers-0.19.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_6

In [3]:
import transformers, huggingface_hub
print(transformers.__version__, huggingface_hub.__version__)

4.44.2 0.36.0


In [4]:
os.environ["TRANSFORMERS_NO_ADDITIONAL_CHAT_TEMPLATES"] = "1"

In [5]:
from transformers import RobertaModel, RobertaTokenizer, BertTokenizer
from transformers import BertModel

In [6]:
#for dirname, _, filenames in os.walk('/kaggle/input'):
#    for filename in filenames:
#        print(os.path.join(dirname, filename))


captions_2014_path = "/kaggle/input/coco-image-caption/annotations_trainval2014/annotations/captions_train2014.json"
captions_2017_path = "/kaggle/input/coco-image-caption/annotations_trainval2017/annotations/captions_val2017.json"

# Load JSON
with open(captions_2014_path, "r") as f:
    captions_2014 = json.load(f)

with open(captions_2017_path, "r") as f:
    captions_2017 = json.load(f)

# Convert to DataFrames
df_captions_2014 = pd.DataFrame(captions_2014["annotations"])
df_captions_2017 = pd.DataFrame(captions_2017["annotations"])

print("Train 2014 captions:", df_captions_2014.shape)
print("Val 2017 captions:", df_captions_2017.shape)
print("\nSample caption:\n", df_captions_2017['caption'].iloc[0])

Train 2014 captions: (414113, 3)
Val 2017 captions: (25014, 3)

Sample caption:
 A black Honda motorcycle parked in front of a garage.


In [7]:
image_dir_2014 = "/kaggle/input/coco-image-caption/train2014/train2014"
image_dir_2017 = "/kaggle/input/coco-image-caption/val2017/val2017"

In [8]:
df_captions_2014['source_dir'] = image_dir_2014
df_captions_2017['source_dir'] = image_dir_2017

In [9]:
df_captions_2014['image_id'] = (
    'COCO_train2014_' + df_captions_2014['image_id'].astype(str).str.zfill(12)
)

df_captions_2017['image_id'] = (
    df_captions_2017['image_id'].astype(str).str.zfill(12)
)

In [10]:
pd.set_option('display.max_colwidth', None)

In [11]:
df_captions_2014['source_dir']

0         /kaggle/input/coco-image-caption/train2014/train2014
1         /kaggle/input/coco-image-caption/train2014/train2014
2         /kaggle/input/coco-image-caption/train2014/train2014
3         /kaggle/input/coco-image-caption/train2014/train2014
4         /kaggle/input/coco-image-caption/train2014/train2014
                                  ...                         
414108    /kaggle/input/coco-image-caption/train2014/train2014
414109    /kaggle/input/coco-image-caption/train2014/train2014
414110    /kaggle/input/coco-image-caption/train2014/train2014
414111    /kaggle/input/coco-image-caption/train2014/train2014
414112    /kaggle/input/coco-image-caption/train2014/train2014
Name: source_dir, Length: 414113, dtype: object

In [12]:
df_coco_unified = pd.concat([df_captions_2014, df_captions_2017], ignore_index=True)
df_coco_unified = df_coco_unified[['image_id', 'caption','source_dir']]
df_coco_unified['source'] = 'COCO'

# Rename 'image_id' for consistency if you plan to combine with Flickr, 
# although image access will differ (COCO uses 'image_id' to format the filename).
df_coco_unified.rename(columns={'image_id': 'unique_image_identifier'}, inplace=True)
df_coco_unified['unique_image_identifier'] = df_coco_unified['unique_image_identifier'].astype(str)

In [13]:
df_coco_unified.head()

Unnamed: 0,unique_image_identifier,caption,source_dir,source
0,COCO_train2014_000000318556,A very clean and well decorated empty bathroom,/kaggle/input/coco-image-caption/train2014/train2014,COCO
1,COCO_train2014_000000116100,A panoramic view of a kitchen and all of its appliances.,/kaggle/input/coco-image-caption/train2014/train2014,COCO
2,COCO_train2014_000000318556,A blue and white bathroom with butterfly themed wall tiles.,/kaggle/input/coco-image-caption/train2014/train2014,COCO
3,COCO_train2014_000000116100,A panoramic photo of a kitchen and dining room,/kaggle/input/coco-image-caption/train2014/train2014,COCO
4,COCO_train2014_000000379340,A graffiti-ed stop sign across the street from a red car,/kaggle/input/coco-image-caption/train2014/train2014,COCO


In [14]:
flickr_annotations_path = "/kaggle/input/flickr-image-dataset/flickr30k_images/results.csv"
flickr_image_path = "/kaggle/input/flickr-image-dataset/flickr30k_images/flickr30k_images"

In [15]:
df_flickr = pd.read_csv(flickr_annotations_path, delimiter = '|')
df_flickr['source_dir'] = flickr_image_path
df_flickr.columns

Index(['image_name', ' comment_number', ' comment', 'source_dir'], dtype='object')

In [16]:

try:
    # Use the column names and delimiter identified in our previous steps
    df_flickr = pd.read_csv(
        flickr_annotations_path,
        delimiter='|',
        names=['image_name', 'comment_number', 'caption'],
        header=None,
        encoding='utf-8',
        skiprows=1 
    )
    df_flickr['source_dir'] = flickr_image_path
    # Flickr image names are already unique strings (e.g., '1000092795.jpg')
    df_flickr = df_flickr[['image_name', 'caption','source_dir']]
    df_flickr.rename(columns={'image_name': 'unique_image_identifier'}, inplace=True)
    df_flickr['source'] = 'Flickr'
except Exception as e:
    print(f"Error loading Flickr CSV: {e}. Using a small dummy set for Flickr.")
    df_flickr = pd.DataFrame({
        'unique_image_identifier': ['dummy1.jpg', 'dummy2.jpg'],
        'caption': ['A placeholder image.', 'Another example sentence.'],
        'source': 'Flickr'
    })

In [17]:
df_combined = pd.concat([df_coco_unified, df_flickr], ignore_index=True)

print("Combined DataFrame Shape:", df_combined.shape)
print("Combined Data Sources:\n", df_combined['source'].value_counts())
print("Combined DataFrame Head:")
print(df_combined[df_combined['source_dir'].str.contains('2014', na=False)].head())

Combined DataFrame Shape: (598042, 4)
Combined Data Sources:
 source
COCO      439127
Flickr    158915
Name: count, dtype: int64
Combined DataFrame Head:
       unique_image_identifier  \
0  COCO_train2014_000000318556   
1  COCO_train2014_000000116100   
2  COCO_train2014_000000318556   
3  COCO_train2014_000000116100   
4  COCO_train2014_000000379340   

                                                       caption  \
0               A very clean and well decorated empty bathroom   
1     A panoramic view of a kitchen and all of its appliances.   
2  A blue and white bathroom with butterfly themed wall tiles.   
3               A panoramic photo of a kitchen and dining room   
4    A graffiti-ed stop sign across the street from a red car    

                                             source_dir source  
0  /kaggle/input/coco-image-caption/train2014/train2014   COCO  
1  /kaggle/input/coco-image-caption/train2014/train2014   COCO  
2  /kaggle/input/coco-image-caption/train2014/tra

In [18]:
from sklearn.model_selection import train_test_split

# 80% train, 10% val, 10% test
train_df, temp_df = train_test_split(df_combined, test_size=0.2, random_state=42)
val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42)

print(len(train_df), len(val_df), len(test_df))


478433 59804 59805


In [19]:
train_df['source_dir']

52151                     /kaggle/input/coco-image-caption/train2014/train2014
595847    /kaggle/input/flickr-image-dataset/flickr30k_images/flickr30k_images
514118    /kaggle/input/flickr-image-dataset/flickr30k_images/flickr30k_images
160750                    /kaggle/input/coco-image-caption/train2014/train2014
477213    /kaggle/input/flickr-image-dataset/flickr30k_images/flickr30k_images
                                          ...                                 
110268                    /kaggle/input/coco-image-caption/train2014/train2014
259178                    /kaggle/input/coco-image-caption/train2014/train2014
365838                    /kaggle/input/coco-image-caption/train2014/train2014
131932                    /kaggle/input/coco-image-caption/train2014/train2014
121958                    /kaggle/input/coco-image-caption/train2014/train2014
Name: source_dir, Length: 478433, dtype: object

In [20]:
# --- Configuration ---
MAX_LEN = 128 # Fixed sequence length for the Transformer input
BERT_MODEL = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(BERT_MODEL)

bert_model = BertModel.from_pretrained('bert-base-uncased')
bert_model.eval()

device = 'cuda' if torch.cuda.is_available() else 'cpu'
bert_model.to(device)

# --- 1. Tokenization Function (adapted from previous step) ---
def prepare_caption_for_roberta(caption, tokenizer, max_len):
    """Tokenizes a caption using RoBERTa's subword tokenizer."""
    # 'comment' is the raw string caption from the DataFrame
    encoding = tokenizer.encode_plus(
        str(caption),
        add_special_tokens=True,
        max_length=max_len,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt'
    )
    # Squeeze(0) converts (1, MAX_LEN) to (MAX_LEN)
    return encoding['input_ids'].squeeze(0), encoding['attention_mask'].squeeze(0)

# --- 2. Custom PyTorch Dataset ---
class TextToImageDataset(Dataset):
    def __init__(self, df_combined, tokenizer, max_len, transform=None):
        self.data = df_combined
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.transform = transform
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        row = self.data.iloc[idx]
    
        # Text encoding
        input_ids, attention_mask = prepare_caption_for_roberta(row['caption'], self.tokenizer, self.max_len)
        #input_ids = input_ids.to(device)           # remove unsqueeze(0)
        #attention_mask = attention_mask.to(device)

        #with torch.no_grad():
            #outputs = bert_model(input_ids=input_ids.unsqueeze(0), attention_mask=attention_mask.unsqueeze(0))
            #text_embedding = outputs.last_hidden_state.squeeze(0)
        # Image loading
        img_name = row['unique_image_identifier'] + ('.jpg' if '.' not in row['unique_image_identifier'] else '')
        img_dir = row['source_dir']
        img_path = img_dir + '/' + img_name
    
        try:
            image = Image.open(img_path).convert('RGB')
        except FileNotFoundError:
            print(f"Warning: Image not found at {img_path}")
            image = Image.new('RGB', (256, 256), color='black')
    
        if self.transform:
            image = self.transform(image)
    
        return image, input_ids, attention_mask


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]



model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

In [21]:
image_transforms = T.Compose([
    T.Resize((256, 256)), # Target size for the VAE
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
train_dataset = TextToImageDataset(train_df, tokenizer, MAX_LEN, image_transforms)
val_dataset   = TextToImageDataset(val_df, tokenizer, MAX_LEN, image_transforms)
test_dataset  = TextToImageDataset(test_df, tokenizer, MAX_LEN, image_transforms)

In [22]:
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader   = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
test_loader  = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

print("Train samples:", len(train_dataset))
print("Val samples:", len(val_dataset))
print("Test samples:", len(test_dataset))

Train samples: 478433
Val samples: 59804
Test samples: 59805


In [23]:
class CrossAttention(nn.Module):
    def __init__(self, dim, context_dim, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        
        self.query = nn.Linear(dim, dim)
        self.key   = nn.Linear(context_dim, dim)
        self.value = nn.Linear(context_dim, dim)
        self.proj  = nn.Linear(dim, dim)
        
    def forward(self, x, context):
        B, N, C = x.shape  # image tokens
        _, M, D = context.shape  # text tokens

        q = self.query(x)
        #print(q.shape, 'query shape')
        k = self.key(context)
        #print(k.shape, 'key shape')
        v = self.value(context)
        #print(v.shape, 'after value shape')

        q = q.view(B, N, self.num_heads, C//self.num_heads).transpose(1,2)
        k = k.view(B, M, self.num_heads, C//self.num_heads).transpose(1,2)
        v = v.view(B, M, self.num_heads, C//self.num_heads).transpose(1,2)
        #print(q.shape, k.shape, v.shape)
        attn = (q @ k.transpose(-2, -1)) / (C**0.5)
        attn = attn.softmax(dim=-1)
        #print(attn.shape, 'attention shape')
        out = attn @ v
        out = out.transpose(1,2).reshape(B, N, C)
        #print(out.shape,'out shape')
        return self.proj(out)

In [24]:
class TextConditionedDecoder(nn.Module):
    def __init__(self, latent_dim=4, hidden_dim=256, text_dim=768):
        super().__init__()

        self.initial_conv = nn.Conv2d(latent_dim, hidden_dim, 3, padding=1)

        self.block1 = nn.Sequential(
            nn.ConvTranspose2d(hidden_dim, hidden_dim//2, kernel_size = 4, stride = 2, padding=1),
            nn.GroupNorm(8, hidden_dim//2),
            nn.SiLU(),
        )
        self.cross1 = CrossAttention(hidden_dim, text_dim)

        self.block2 = nn.Sequential(
            nn.ConvTranspose2d(hidden_dim//2, hidden_dim//4, kernel_size=4, stride=2, padding=1),
            nn.GroupNorm(8, hidden_dim//4),
            nn.SiLU(),
        )
        self.cross2 = CrossAttention(hidden_dim//2, text_dim)

        self.block3 = nn.Sequential(
            nn.ConvTranspose2d(hidden_dim//4, hidden_dim//8, kernel_size=4, stride=2, padding=1),
            nn.GroupNorm(8, hidden_dim//8),
            nn.SiLU(),
        )
        
        self.cross3 = CrossAttention(hidden_dim//4, text_dim)

        self.final_conv = nn.Conv2d(hidden_dim//8, 3, 3, padding=1)

    def forward(self, z, text_emb):
        # z: (B, 4, H/8, W/8)
        B, C, H, W = z.shape
        x = self.initial_conv(z)
        #print(x.shape,'after initial conv')
        # Tokenize spatial map for cross-attention
        x_tokens = x.flatten(2).transpose(1,2)  # (B, HW, C)
        #print(x_tokens.shape, 'after flatten')
        x_tokens = x_tokens + self.cross1(x_tokens, text_emb)
        #print(x_tokens.shape, 'after one cross')
        x = x_tokens.transpose(1,2).reshape(B, -1, H, W)
        #print(x.shape, 'after transpose and reshape')
        x = self.block1(x)
        #print(x.shape, 'after block1')

        x_tokens = x.flatten(2).transpose(1,2)
        #print(x_tokens.shape, 'after flatten 2')
        x_tokens = x_tokens + self.cross2(x_tokens, text_emb)
        #print(x_tokens.shape, 'after cross2')
        x = x_tokens.transpose(1,2).reshape(B, -1, H*2, W*2)
        #print(x.shape, 'after transpose and reshape')
        x = self.block2(x)
        #print(x.shape,'after block2')


        H, W = x.shape[2], x.shape[3]
        x_tokens = x.flatten(2).transpose(1,2)
        x_tokens = x_tokens + self.cross3(x_tokens, text_emb)
        x = x_tokens.transpose(1,2).reshape(B, -1, H, W)
        x = self.block3(x)  # -> 256x256
        #print(x.shape, 'final block shape')
        return torch.sigmoid(self.final_conv(x))


In [25]:
class VAEEncoder(nn.Module):
    def __init__(self, in_channels=3, latent_dim=4, hidden_dim=256):
        super().__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, hidden_dim//2, 3, stride=2, padding=1),  # 256→128
            nn.GroupNorm(8, hidden_dim//2),
            nn.SiLU(),
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(hidden_dim//2, hidden_dim, 3, stride=2, padding=1),    # 128→64
            nn.GroupNorm(8, hidden_dim),
            nn.SiLU(),
        )

        self.conv3 = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, 3, stride=2, padding=1),      # 64→32
            nn.GroupNorm(8, hidden_dim),
            nn.SiLU(),
        )

        # Final conv layers before producing mean & logvar
        self.conv_out = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.GroupNorm(8, hidden_dim),
            nn.SiLU(),
        )

        # Mean and log-variance projection heads
        self.to_mean   = nn.Conv2d(hidden_dim, latent_dim, 3, padding=1)
        self.to_logvar = nn.Conv2d(hidden_dim, latent_dim, 3, padding=1)

    def forward(self, x):
        # x : (B, 3, 256, 256)
        x = self.conv1(x)  # -> (B, 128, 128, 128)
        x = self.conv2(x)  # -> (B, 256,  64,  64)
        x = self.conv3(x)  # -> (B, 256,  32,  32)

        x = self.conv_out(x)  

        mu     = self.to_mean(x)     # (B, 4, 32, 32)
        logvar = self.to_logvar(x)   # (B, 4, 32, 32)

        return mu, logvar


def reparameterize(mu, logvar):
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std


In [26]:
def kl_loss(mu, logvar, reduction='mean'):
    """
    KL divergence between encoded latent distribution and N(0,1)
    """
    kl = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
    if reduction == 'sum':
        return kl.sum()
    return kl.mean()


In [27]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

encoder = VAEEncoder().to(device)
decoder = TextConditionedDecoder().to(device)
bert_model.eval().to(device)  # freeze BERT for now

optimizer = torch.optim.AdamW(list(encoder.parameters()) + list(decoder.parameters()), lr=2e-4)

In [28]:
torch.cuda.empty_cache()

In [None]:
num_epochs = 5

for epoch in range(num_epochs):
    encoder.train()
    decoder.train()
    
    for batch in tqdm(train_loader, desc=f"Epoch {epoch}"):
        images, input_ids, attention_mask = batch  # images: (B,3,256,256), text_embeddings: (B, max_len, 768)
        images = images.to(device)
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

        with torch.no_grad():
            outputs = bert_model(input_ids=input_ids, attention_mask=attention_mask)
            text_embedding = outputs.last_hidden_state
            text_embedding = text_embedding.to(device)
            
        #print(text_embeddings.shape)
        optimizer.zero_grad()

        # --- Encode image ---
        mu, logvar = encoder(images)
        z = reparameterize(mu, logvar)  # (B, latent_dim, H, W)
        #print(z.shape)
        # --- Decode with text conditioning ---
        recon = decoder(z, text_embedding)  # output: (B,3,256,256)

        # --- Compute Loss ---
        recon_loss = F.mse_loss(recon, images, reduction='mean')
        kld = kl_loss(mu, logvar, reduction='mean')
        loss = recon_loss + kld

        # --- Backprop ---
        loss.backward()
        optimizer.step()

        del images, input_ids, attention_mask, text_embedding, mu, logvar, z, recon, loss
        torch.cuda.empty_cache()

    print(f"Epoch {epoch} | Recon Loss: {recon_loss.item():.4f} | KL Loss: {kld.item():.4f}")


Epoch 0:  38%|███▊      | 5644/14952 [1:38:59<2:45:22,  1.07s/it]

In [None]:
!kill 2584


In [None]:
# Define paths
encoder_path = "/kaggle/working/encoder.pth"
decoder_path = "decoder.pth"
optimizer_path = "optimizer.pth"

# Save models
torch.save(encoder.state_dict(), encoder_path)
torch.save(decoder.state_dict(), decoder_path)
torch.save(optimizer.state_dict(), optimizer_path)

print("Models and optimizer saved successfully!")