In [None]:

!mkdir -p /content/drive/MyDrive/medical-image-caption-data/

%cd /content/drive/MyDrive/


!tar -xvzf NLMCXR_png.tgz -C ./medical-image-caption-data/


!tar -xvzf NLMCXR_reports.tgz -C ./medical-image-caption-data/


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
./CXR3471_IM-1687-2002.png
./CXR3810_IM-1920-2001.png
./CXR122_IM-0147-0001-0001.png
./CXR2927_IM-1329-1001.png
./CXR594_IM-2187-2001.png
./CXR683_IM-2254-1001.png
./CXR234_IM-0906-0001-0001.png
./CXR3292_IM-1572-1001.png
./CXR81_IM-2343-2001.png
./CXR1347_IM-0225-3001.png
./CXR2543_IM-1054-3001.png
./CXR3315_IM-1586-2001.png
./CXR2471_IM-1002-1001.png
./CXR3468_IM-1684-0001-0004.png
./CXR3060_IM-1426-1003.png
./CXR57_IM-2170-1001-0002.png
./CXR3858_IM-1953-4004.png
./CXR1541_IM-0352-2001.png
./CXR2845_IM-1254-2001.png
./CXR2549_IM-1057-1001.png
./CXR210_IM-0730-1001.png
./CXR2117_IM-0745-1001.png
./CXR3978_IM-2037-0001-0002.png
./CXR3881_IM-1969-1001.png
./CXR3854_IM-1950-2001.png
./CXR2195_IM-0805-2001.png
./CXR439_IM-2078-2001.png
./CXR1280_IM-0187-1001.png
./CXR666_IM-2241-1001.png
./CXR2804_IM-1235-1001.png
./CXR3783_IM-1898-1001.png
./CXR3533_IM-1726-2001.png
./CXR516_IM-2130-2001.png
./CXR379_IM-1903-2001.png
./CXR

## CONNECT TO MY GOOGLE DRIVE

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:

image_dir = '/content/drive/MyDrive/medical-image-caption-data'
report_dir = '/content/drive/MyDrive/medical-report'


In [4]:
import os
print("Number of images:", len(os.listdir(image_dir)))
print("Number of reports:", len(os.listdir(report_dir)))

Number of images: 2065
Number of reports: 3955


## import the libraries

In [5]:
import os
from PIL import Image
import xml.etree.ElementTree as ET
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import AutoTokenizer
from tqdm import tqdm
import torch.nn as nn


## Pre-process image

In [6]:
from torchvision import transforms

def get_image_transform(train=True):
    if train:
        # data enhancement in training
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(degrees=10),
            transforms.ColorJitter(brightness=0.2, contrast=0.2),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
    else:
        # testing/evaluate progress
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
    return transform


In [7]:

train_transform = get_image_transform(train=True)

val_transform = get_image_transform(train=False)


## Initialize Tokenizer --- BERT --- Compare to the T-5 smaller later


In [9]:
from transformers import AutoTokenizer

# 加载 HuggingFace 的 tokenizer（选用 BERT）
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

## Customize the dataset class

In [10]:
from PIL import Image
class IUXRayDataset(Dataset):
    def __init__(self, image_dir, report_dir, tokenizer, max_length=100):
        self.image_dir = image_dir
        self.report_dir = report_dir
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = self.load_samples()

        # process image
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor()
        ])

    def load_samples(self):
        print("===> Start loading samples from report and image...")
        samples = []
        count_img = 0
        count_match = 0
        count_missing_report = 0

        for img_name in tqdm(os.listdir(self.image_dir), desc="Parsing image-report pairs"):
            if not img_name.endswith('.png'):
                continue

            count_img += 1

            # extract the id
            base_id = img_name.split('_')[0]  # CXR1005
            report_id = base_id.replace('CXR', '')  # 1005

            # do the xml path
            xml_path = os.path.join(self.report_dir, report_id + ".xml")

            if not os.path.exists(xml_path):
                count_missing_report += 1
                continue

            # load the  xml
            try:
                tree = ET.parse(xml_path)
                root = tree.getroot()
            except:
                continue

            findings, impression = "", ""
            for abstract in root.findall('.//AbstractText'):
                label = abstract.attrib.get('Label')
                if label == 'FINDINGS':
                    findings = abstract.text or ""
                if label == 'IMPRESSION':
                    impression = abstract.text or ""

            text = (findings + " " + impression).strip().replace("\n", " ")
            if not text:
                continue

            samples.append((img_name, text))
            count_match += 1

        print(f"✅ Total images scanned: {count_img}")
        print(f"✅ Total matched samples with report: {count_match}")
        print(f"⚠️  Total images missing report: {count_missing_report}")
        return samples



    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_name, text = self.samples[idx]
        img_path = os.path.join(self.image_dir, img_name)


        image = Image.open(img_path).convert("RGB")

        # transform to Tensor
        image = self.transform(image)

        tokens = self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )

        return {
            'image': image,  # Tensor
            'input_ids': tokens.input_ids.squeeze(0),
            'attention_mask': tokens.attention_mask.squeeze(0)
        }

## Create Dataset and DataLoader

In [11]:
# create the complete  Dataset

full_dataset = IUXRayDataset(
    image_dir=image_dir,
    report_dir=report_dir,
    tokenizer=tokenizer,
    max_length=100  # control the length of tokenizer
)


===> Start loading samples from report and image...


Parsing image-report pairs: 100%|██████████| 2065/2065 [01:19<00:00, 25.89it/s] 

✅ Total images scanned: 2063
✅ Total matched samples with report: 2054
⚠️  Total images missing report: 0





In [12]:
print(full_dataset[0])

{'image': tensor([[[0.2000, 0.1882, 0.1843,  ..., 0.7843, 0.7922, 0.7922],
         [0.1922, 0.1922, 0.3059,  ..., 0.7882, 0.7843, 0.7843],
         [0.1922, 0.2000, 0.5882,  ..., 0.7804, 0.7882, 0.7804],
         ...,
         [0.7608, 0.7529, 0.7529,  ..., 0.3490, 0.3490, 0.3569],
         [0.7647, 0.7647, 0.7647,  ..., 0.4314, 0.4392, 0.4471],
         [0.7725, 0.7647, 0.7647,  ..., 0.6824, 0.6902, 0.7059]],

        [[0.2000, 0.1882, 0.1843,  ..., 0.7843, 0.7922, 0.7922],
         [0.1922, 0.1922, 0.3059,  ..., 0.7882, 0.7843, 0.7843],
         [0.1922, 0.2000, 0.5882,  ..., 0.7804, 0.7882, 0.7804],
         ...,
         [0.7608, 0.7529, 0.7529,  ..., 0.3490, 0.3490, 0.3569],
         [0.7647, 0.7647, 0.7647,  ..., 0.4314, 0.4392, 0.4471],
         [0.7725, 0.7647, 0.7647,  ..., 0.6824, 0.6902, 0.7059]],

        [[0.2000, 0.1882, 0.1843,  ..., 0.7843, 0.7922, 0.7922],
         [0.1922, 0.1922, 0.3059,  ..., 0.7882, 0.7843, 0.7843],
         [0.1922, 0.2000, 0.5882,  ..., 0.7804, 

##Summary:
every image has a report, but not every report can find a matching image

## Test the dataloader and pair the image and report


In [13]:
import torch
from torch.utils.data import Subset

# subtract the first 10 samples
small_dataset = Subset(full_dataset, list(range(10)))

print(f"Amount of small samples: {len(small_dataset)}")




# built the  DataLoader
small_loader = DataLoader(small_dataset, batch_size=2, shuffle=True)

# testing
for batch in small_loader:
    print(batch['image'].shape)          # ✅ [B, 3, H, W]
    print(batch['input_ids'].shape)      # ✅ [B, T]
    print(batch['attention_mask'].shape) # ✅ [B, T]
    break


Amount of small samples: 10
torch.Size([2, 3, 224, 224])
torch.Size([2, 100])
torch.Size([2, 100])


# Training DataLoader & validation dataloader

In [15]:


train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size

train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [train_size, val_size])

print(f"Training dataset: {len(train_dataset)}")
print(f"Testing Dataset: {len(val_dataset)}")

# Dataloader
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

Training dataset: 1643
Testing Dataset: 411


## Initialize the model
## Encoder & Decoder

In [37]:
from transformers import CLIPModel, CLIPProcessor

class EncoderCLIP(nn.Module):
    def __init__(self, embed_size, device):
        super(EncoderCLIP, self).__init__()
        self.device = device
        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")


        self.fc = nn.Linear(self.clip_model.config.projection_dim, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)

    def forward(self, images):

        inputs = self.processor(images=images, return_tensors="pt").to(self.device)
        image_features = self.clip_model.get_image_features(**inputs)  # [batch, 512]
        image_features = self.fc(image_features)                       # [batch, embed_size]
        image_features = self.bn(image_features)




        return image_features


In [38]:
class DecoderTransformer(nn.Module):
    def __init__(self, embed_size, vocab_size, num_layers=4, heads=4, ff_hidden=1024, dropout=0.3, max_len=100):
        super(DecoderTransformer, self).__init__()

        #  token embedding
        self.embed = nn.Embedding(vocab_size, embed_size)


        self.positional_encoding = PositionalEncoding(embed_size, dropout, max_len)


        decoder_layer = nn.TransformerDecoderLayer(
            d_model=embed_size,
            nhead=heads,
            dim_feedforward=ff_hidden,
            dropout=dropout,
            batch_first=True
        )

        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

        #  vocab size
        self.fc = nn.Linear(embed_size, vocab_size)

    def forward(self, image_features, captions):

        embeddings = self.embed(captions)
        embeddings = self.positional_encoding(embeddings)


        tgt_mask = self.generate_square_subsequent_mask(captions.size(1)).to(captions.device)


        output = self.transformer_decoder(
            tgt=embeddings,
            memory=image_features.unsqueeze(1),
            tgt_mask=tgt_mask
        )


        output = self.fc(output)
        return output

    def generate_square_subsequent_mask(self, sz):
        return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)


# 位置编码模块（经典做法）
class PositionalEncoding(nn.Module):
    def __init__(self, embed_size, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, embed_size)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_size, 2).float() * (-torch.log(torch.tensor(10000.0)) / embed_size))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        pe = pe.unsqueeze(0)  # [1, max_len, embed]
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)


In [39]:


class ImageCaptionModel(nn.Module):
    def __init__(self, embed_size, vocab_size, device):
        super(ImageCaptionModel, self).__init__()
        self.encoder = EncoderCLIP(embed_size, device)
        self.decoder = DecoderTransformer(embed_size, vocab_size)

    def forward(self, images, captions):
        image_features = self.encoder(images)  # [batch, embed]
        outputs = self.decoder(image_features, captions)
        return outputs


In [40]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ImageCaptionModel(embed_size=512, vocab_size=len(tokenizer), device=device).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

# EXP01

### TRAIN IN A LOOP

In [16]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
import os

def train(model, train_loader, optimizer, criterion, tokenizer, device, num_epochs=10, save_path="checkpoint.pth"):
    model.to(device)
    model.train()

    for epoch in range(num_epochs):
        total_loss = 0
        loop = tqdm(train_loader, leave=True, dynamic_ncols=True)
        for batch in loop:
            images = batch['image'].to(device)
            captions = batch['input_ids'].to(device)

            optimizer.zero_grad()
            outputs = model(images, captions[:, :-1])
            loss = criterion(outputs.reshape(-1, outputs.size(-1)), captions[:, 1:].reshape(-1))
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            loop.set_description(f"Epoch [{epoch+1}/{num_epochs}]")
            loop.set_postfix(loss=loss.item())

        print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")

        # save the latest checkpoint for every epoch
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': total_loss
        }

        os.makedirs(os.path.dirname(save_path), exist_ok=True)

        torch.save(checkpoint, save_path)
        print(f"Checkpoint saved at {save_path}")


In [19]:
train(model, train_loader, optimizer, criterion, tokenizer, device, num_epochs=10,save_path="/content/EXP01-checkpoint.pth")

Epoch [1/10]: 100%|██████████| 52/52 [09:03<00:00, 10.45s/it, loss=4.98]


Epoch 1, Loss: 329.2802
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [2/10]: 100%|██████████| 52/52 [00:46<00:00,  1.11it/s, loss=3.42]


Epoch 2, Loss: 205.4930
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [3/10]: 100%|██████████| 52/52 [00:47<00:00,  1.10it/s, loss=3.55]


Epoch 3, Loss: 164.9561
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [4/10]: 100%|██████████| 52/52 [00:46<00:00,  1.11it/s, loss=2.06]


Epoch 4, Loss: 143.8303
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [5/10]: 100%|██████████| 52/52 [00:47<00:00,  1.10it/s, loss=2.16]


Epoch 5, Loss: 130.7519
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [6/10]: 100%|██████████| 52/52 [00:47<00:00,  1.11it/s, loss=2.02]


Epoch 6, Loss: 121.0142
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [7/10]: 100%|██████████| 52/52 [00:46<00:00,  1.11it/s, loss=1.88]


Epoch 7, Loss: 113.5680
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [8/10]: 100%|██████████| 52/52 [00:46<00:00,  1.11it/s, loss=2.11]


Epoch 8, Loss: 107.6193
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [9/10]: 100%|██████████| 52/52 [00:47<00:00,  1.09it/s, loss=1.85]


Epoch 9, Loss: 103.0175
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [10/10]: 100%|██████████| 52/52 [00:47<00:00,  1.09it/s, loss=1.85]


Epoch 10, Loss: 99.0398
Checkpoint saved at /content/EXP01-checkpoint.pth


In [17]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
import os

def train(model, train_loader, optimizer, criterion, tokenizer, device, num_epochs=10, save_path="checkpoint.pth"):
    model.to(device)

    start_epoch = 0


    if os.path.exists(save_path):
        print(f"===> Found checkpoint at {save_path}, loading ...")
        checkpoint = torch.load(save_path, map_location=device)

        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1

        print(f"===> Resume training from epoch {start_epoch}")

    model.train()

    for epoch in range(start_epoch, num_epochs):
        total_loss = 0
        loop = tqdm(train_loader, leave=True)

        for batch in loop:
            images = batch['image'].to(device)
            captions = batch['input_ids'].to(device)

            optimizer.zero_grad()
            outputs = model(images, captions[:, :-1])
            loss = criterion(outputs.reshape(-1, outputs.size(-1)), captions[:, 1:].reshape(-1))
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            loop.set_description(f"Epoch [{epoch+1}/{num_epochs}]")
            loop.set_postfix(loss=loss.item())

        print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")


        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': total_loss
        }

        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        torch.save(checkpoint, save_path)
        print(f"Checkpoint saved at {save_path}")


In [21]:
train(model, train_loader, optimizer, criterion, tokenizer, device, num_epochs=463,save_path="/content/EXP01-checkpoint.pth")

===> Found checkpoint at /content/EXP01-checkpoint.pth, loading ...
===> Resume training from epoch 400


Epoch [401/700]: 100%|██████████| 52/52 [00:47<00:00,  1.09it/s, loss=0.133]


Epoch 401, Loss: 7.7084
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [402/700]: 100%|██████████| 52/52 [00:47<00:00,  1.09it/s, loss=0.139]


Epoch 402, Loss: 7.7246
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [403/700]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.153]


Epoch 403, Loss: 7.6019
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [404/700]: 100%|██████████| 52/52 [00:47<00:00,  1.10it/s, loss=0.172]


Epoch 404, Loss: 7.5792
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [405/700]: 100%|██████████| 52/52 [00:47<00:00,  1.10it/s, loss=0.168]


Epoch 405, Loss: 7.6160
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [406/700]: 100%|██████████| 52/52 [00:47<00:00,  1.09it/s, loss=0.175]


Epoch 406, Loss: 7.9482
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [407/700]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.159]


Epoch 407, Loss: 7.6551
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [408/700]: 100%|██████████| 52/52 [00:47<00:00,  1.09it/s, loss=0.131]


Epoch 408, Loss: 7.4425
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [409/700]: 100%|██████████| 52/52 [00:47<00:00,  1.09it/s, loss=0.159]


Epoch 409, Loss: 7.4935
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [410/700]: 100%|██████████| 52/52 [00:47<00:00,  1.10it/s, loss=0.128]


Epoch 410, Loss: 7.5729
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [411/700]: 100%|██████████| 52/52 [00:50<00:00,  1.03it/s, loss=0.134]


Epoch 411, Loss: 7.4788
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [412/700]: 100%|██████████| 52/52 [00:47<00:00,  1.09it/s, loss=0.149]


Epoch 412, Loss: 7.5341
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [413/700]: 100%|██████████| 52/52 [00:47<00:00,  1.10it/s, loss=0.142]


Epoch 413, Loss: 7.4374
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [414/700]: 100%|██████████| 52/52 [00:47<00:00,  1.09it/s, loss=0.146]


Epoch 414, Loss: 7.3165
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [415/700]: 100%|██████████| 52/52 [00:47<00:00,  1.10it/s, loss=0.152]


Epoch 415, Loss: 7.2584
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [416/700]: 100%|██████████| 52/52 [00:47<00:00,  1.09it/s, loss=0.121]


Epoch 416, Loss: 7.2213
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [417/700]: 100%|██████████| 52/52 [00:46<00:00,  1.11it/s, loss=0.161]


Epoch 417, Loss: 7.3413
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [418/700]: 100%|██████████| 52/52 [00:46<00:00,  1.11it/s, loss=0.187]


Epoch 418, Loss: 7.4693
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [419/700]: 100%|██████████| 52/52 [00:46<00:00,  1.11it/s, loss=0.131]


Epoch 419, Loss: 7.3630
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [420/700]: 100%|██████████| 52/52 [00:47<00:00,  1.10it/s, loss=0.143]


Epoch 420, Loss: 7.1609
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [421/700]: 100%|██████████| 52/52 [00:46<00:00,  1.11it/s, loss=0.15]


Epoch 421, Loss: 7.2625
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [422/700]: 100%|██████████| 52/52 [00:46<00:00,  1.11it/s, loss=0.149]


Epoch 422, Loss: 7.3415
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [423/700]: 100%|██████████| 52/52 [00:46<00:00,  1.11it/s, loss=0.106]


Epoch 423, Loss: 7.1600
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [424/700]: 100%|██████████| 52/52 [00:46<00:00,  1.11it/s, loss=0.118]


Epoch 424, Loss: 7.1742
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [425/700]: 100%|██████████| 52/52 [00:46<00:00,  1.11it/s, loss=0.146]


Epoch 425, Loss: 7.1882
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [426/700]: 100%|██████████| 52/52 [00:47<00:00,  1.10it/s, loss=0.155]


Epoch 426, Loss: 7.2023
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [427/700]: 100%|██████████| 52/52 [00:46<00:00,  1.11it/s, loss=0.152]


Epoch 427, Loss: 7.1173
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [428/700]: 100%|██████████| 52/52 [00:46<00:00,  1.11it/s, loss=0.15]


Epoch 428, Loss: 7.1245
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [429/700]: 100%|██████████| 52/52 [00:47<00:00,  1.10it/s, loss=0.111]


Epoch 429, Loss: 7.2676
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [430/700]: 100%|██████████| 52/52 [00:47<00:00,  1.10it/s, loss=0.213]


Epoch 430, Loss: 7.2355
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [431/700]: 100%|██████████| 52/52 [00:46<00:00,  1.11it/s, loss=0.151]


Epoch 431, Loss: 7.1053
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [432/700]: 100%|██████████| 52/52 [00:46<00:00,  1.11it/s, loss=0.127]


Epoch 432, Loss: 7.0882
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [433/700]: 100%|██████████| 52/52 [00:47<00:00,  1.09it/s, loss=0.115]


Epoch 433, Loss: 7.1453
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [434/700]: 100%|██████████| 52/52 [00:46<00:00,  1.11it/s, loss=0.165]


Epoch 434, Loss: 7.1716
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [435/700]: 100%|██████████| 52/52 [00:46<00:00,  1.11it/s, loss=0.116]


Epoch 435, Loss: 7.0553
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [436/700]: 100%|██████████| 52/52 [00:46<00:00,  1.11it/s, loss=0.151]


Epoch 436, Loss: 6.9555
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [437/700]: 100%|██████████| 52/52 [00:47<00:00,  1.11it/s, loss=0.111]


Epoch 437, Loss: 6.8811
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [438/700]: 100%|██████████| 52/52 [00:46<00:00,  1.11it/s, loss=0.142]


Epoch 438, Loss: 6.9964
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [439/700]: 100%|██████████| 52/52 [00:46<00:00,  1.12it/s, loss=0.127]


Epoch 439, Loss: 7.0732
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [440/700]: 100%|██████████| 52/52 [00:46<00:00,  1.11it/s, loss=0.127]


Epoch 440, Loss: 7.0233
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [441/700]: 100%|██████████| 52/52 [00:47<00:00,  1.11it/s, loss=0.141]


Epoch 441, Loss: 6.9454
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [442/700]: 100%|██████████| 52/52 [00:46<00:00,  1.12it/s, loss=0.145]


Epoch 442, Loss: 6.9663
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [443/700]: 100%|██████████| 52/52 [00:47<00:00,  1.10it/s, loss=0.15]


Epoch 443, Loss: 7.0423
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [444/700]: 100%|██████████| 52/52 [00:47<00:00,  1.10it/s, loss=0.135]


Epoch 444, Loss: 6.9022
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [445/700]: 100%|██████████| 52/52 [00:47<00:00,  1.10it/s, loss=0.112]


Epoch 445, Loss: 7.0638
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [446/700]: 100%|██████████| 52/52 [00:47<00:00,  1.10it/s, loss=0.161]


Epoch 446, Loss: 6.9650
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [447/700]: 100%|██████████| 52/52 [00:47<00:00,  1.09it/s, loss=0.123]


Epoch 447, Loss: 6.8518
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [448/700]: 100%|██████████| 52/52 [00:47<00:00,  1.10it/s, loss=0.118]


Epoch 448, Loss: 6.8284
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [449/700]: 100%|██████████| 52/52 [00:47<00:00,  1.09it/s, loss=0.154]


Epoch 449, Loss: 7.1289
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [450/700]: 100%|██████████| 52/52 [00:47<00:00,  1.10it/s, loss=0.134]


Epoch 450, Loss: 7.0371
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [451/700]: 100%|██████████| 52/52 [00:46<00:00,  1.11it/s, loss=0.127]


Epoch 451, Loss: 6.9866
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [452/700]: 100%|██████████| 52/52 [00:46<00:00,  1.12it/s, loss=0.148]


Epoch 452, Loss: 6.9401
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [453/700]: 100%|██████████| 52/52 [00:46<00:00,  1.11it/s, loss=0.164]


Epoch 453, Loss: 6.9256
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [454/700]: 100%|██████████| 52/52 [00:46<00:00,  1.11it/s, loss=0.23]


Epoch 454, Loss: 9.6422
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [455/700]: 100%|██████████| 52/52 [00:47<00:00,  1.10it/s, loss=0.264]


Epoch 455, Loss: 10.9789
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [456/700]: 100%|██████████| 52/52 [00:46<00:00,  1.12it/s, loss=0.205]


Epoch 456, Loss: 9.9197
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [457/700]: 100%|██████████| 52/52 [00:47<00:00,  1.10it/s, loss=0.163]


Epoch 457, Loss: 9.4037
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [458/700]: 100%|██████████| 52/52 [00:46<00:00,  1.11it/s, loss=0.177]


Epoch 458, Loss: 8.9891
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [459/700]: 100%|██████████| 52/52 [00:46<00:00,  1.12it/s, loss=0.189]


Epoch 459, Loss: 8.5533
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [460/700]: 100%|██████████| 52/52 [00:47<00:00,  1.10it/s, loss=0.153]


Epoch 460, Loss: 8.2169
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [461/700]: 100%|██████████| 52/52 [00:47<00:00,  1.10it/s, loss=0.186]


Epoch 461, Loss: 8.1197
Checkpoint saved at /content/EXP01-checkpoint.pth


Epoch [462/700]:  44%|████▍     | 23/52 [00:21<00:27,  1.06it/s, loss=0.135]
ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.


KeyboardInterrupt



## Evaluation

In [25]:
import torch
from tqdm import tqdm

def evaluate(model, test_loader, tokenizer, device, max_length=100, num_samples=5):
    model.eval()
    model.to(device)

    results = []

    eos_token_id = tokenizer.eos_token_id or tokenizer.sep_token_id or tokenizer.pad_token_id

    with torch.no_grad():
        for i, batch in enumerate(tqdm(test_loader)):
            images = batch['image'].to(device)
            gt_reports = tokenizer.batch_decode(batch['input_ids'], skip_special_tokens=True)


            start_token = tokenizer.bos_token_id or tokenizer.cls_token_id or 0
            input_ids = torch.full((images.size(0), 1), start_token, dtype=torch.long).to(device)

            # greedy decoding
            for _ in range(max_length):
                outputs = model(images, input_ids)           # [B, T, V]
                next_token = outputs[:, -1, :].argmax(dim=-1, keepdim=True)  # [B, 1]
                input_ids = torch.cat([input_ids, next_token], dim=1)

                if eos_token_id is not None:
                    if torch.eq(next_token, eos_token_id).all():
                        break

            # decode
            generated_reports = tokenizer.batch_decode(input_ids, skip_special_tokens=True)

            for gt, pred in zip(gt_reports, generated_reports):
                results.append((gt.strip(), pred.strip()))


            if len(results) >= num_samples:
                break

    print("\n=== Sample Evaluation ===")
    for idx, (gt, pred) in enumerate(results[:num_samples]):
        print(f"\n🔹 Sample {idx+1}")
        print(f"🟢 Ground Truth:\n{gt}")
        print(f"🔵 Prediction:\n{pred}")

    return results


In [26]:
print(f"EOS token ID: {tokenizer.eos_token_id}")

EOS token ID: None


In [27]:

evaluate(model, val_loader, tokenizer, device, num_samples=5)


  0%|          | 0/13 [00:25<?, ?it/s]


=== Sample Evaluation ===

🔹 Sample 1
🟢 Ground Truth:
the cardiac silhouette and mediastinum size are within normal limits. there is no pulmonary edema. there is no focal consolidation. there are no xxxx of a pleural effusion. there is no evidence of pneumothorax. multilevel flowing anterior thoracic spine osteophytes, which could represent changes of diffuse idiopathic skeletal hyperostosis ( dish ). there is no evidence of acute cardiopulmonary disease.
🔵 Prediction:
media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media med




[('the cardiac silhouette and mediastinum size are within normal limits. there is no pulmonary edema. there is no focal consolidation. there are no xxxx of a pleural effusion. there is no evidence of pneumothorax. multilevel flowing anterior thoracic spine osteophytes, which could represent changes of diffuse idiopathic skeletal hyperostosis ( dish ). there is no evidence of acute cardiopulmonary disease.',
  'media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media media medi

# EXP02

## Encoder

In [51]:
from transformers import CLIPModel, CLIPProcessor


class EncoderCLIP(nn.Module):
    def __init__(self, embed_size, device):
        super(EncoderCLIP, self).__init__()
        self.device = device
        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

        self.fc = nn.Linear(768, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)

    def forward(self, images):
        inputs = self.processor(images=images, return_tensors="pt").to(self.device)
        vision_outputs = self.clip_model.vision_model(**inputs)
        patch_embeddings = vision_outputs.last_hidden_state  # [B, seq_len, 768]

        projected = self.fc(patch_embeddings)  # → [B, seq_len, embed]
        projected = self.bn(projected.transpose(1, 2)).transpose(1, 2)

        return projected  # [B, seq_len, embed]



## Decoder

In [52]:
import torch
import torch.nn as nn

class DecoderTransformer(nn.Module):
    def __init__(self, embed_size, vocab_size, num_layers=4, heads=4, ff_hidden=1024, dropout=0.3, max_len=100):
        super(DecoderTransformer, self).__init__()

        self.embed = nn.Embedding(vocab_size, embed_size)
        self.positional_encoding = PositionalEncoding(embed_size, dropout, max_len)

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=embed_size,
            nhead=heads,
            dim_feedforward=ff_hidden,
            dropout=dropout,
            batch_first=True
        )

        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.fc = nn.Linear(embed_size, vocab_size)

    def forward(self, image_features, captions):
        # captions: [B, T] (input_ids)
        embeddings = self.embed(captions)
        embeddings = self.positional_encoding(embeddings)

        tgt_mask = self.generate_square_subsequent_mask(captions.size(1)).to(captions.device)

        output = self.transformer_decoder(
            tgt=embeddings,
            memory=image_features,
            tgt_mask=tgt_mask
        )

        return self.fc(output)

    def generate_square_subsequent_mask(self, sz):
        return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)


class PositionalEncoding(nn.Module):
    def __init__(self, embed_size, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, embed_size)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_size, 2).float() * (-torch.log(torch.tensor(10000.0)) / embed_size))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)


In [53]:
class ImageCaptionModel(nn.Module):
    def __init__(self, embed_size, vocab_size, device):
        super(ImageCaptionModel, self).__init__()
        self.encoder = EncoderCLIP(embed_size, device)
        self.decoder = DecoderTransformer(embed_size, vocab_size)

    def forward(self, images, captions):
        image_features = self.encoder(images)  # [batch, embed]
        outputs = self.decoder(image_features, captions)
        return outputs


## Initialize model

In [54]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ImageCaptionModel(embed_size=512, vocab_size=len(tokenizer), device=device).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

## Train In a Loop

In [55]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
import os

def train(model, train_loader, optimizer, criterion, tokenizer, device, num_epochs=10, save_path="checkpoint.pth"):

    model.to(device)
    start_epoch = 0


    if os.path.exists(save_path):
        print(f"==> Found checkpoint at {save_path}, loading...")
        checkpoint = torch.load(save_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        print(f"==> Resuming training from epoch {start_epoch}")
    model.train()

    for epoch in range(start_epoch,num_epochs):
        total_loss = 0
        loop = tqdm(train_loader, leave=True, dynamic_ncols=True)
        for batch in loop:
            images = batch['image'].to(device)
            captions = batch['input_ids'].to(device)

            optimizer.zero_grad()
            outputs = model(images, captions[:, :-1])
            loss = criterion(outputs.reshape(-1, outputs.size(-1)), captions[:, 1:].reshape(-1))
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            loop.set_description(f"Epoch [{epoch+1}/{num_epochs}]")
            loop.set_postfix(loss=loss.item())

        print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")


        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': total_loss
        }


        os.makedirs(os.path.dirname(save_path), exist_ok=True)

        torch.save(checkpoint, save_path)
        print(f"Checkpoint saved at {save_path}")


In [78]:
train(model, train_loader, optimizer, criterion, tokenizer, device, num_epochs=100,save_path="/content/drive/MyDrive/EXP02-checkpoint.pth")

==> Found checkpoint at /content/EXP02-checkpoint.pth, loading...
==> Resuming training from epoch 10


Epoch [11/100]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.574]


Epoch 11, Loss: 28.9077
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [12/100]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.395]


Epoch 12, Loss: 28.4044
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [13/100]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.743]


Epoch 13, Loss: 28.4104
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [14/100]: 100%|██████████| 52/52 [00:48<00:00,  1.07it/s, loss=0.657]


Epoch 14, Loss: 28.3157
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [15/100]: 100%|██████████| 52/52 [00:48<00:00,  1.07it/s, loss=0.432]


Epoch 15, Loss: 27.9435
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [16/100]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.554]


Epoch 16, Loss: 27.7431
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [17/100]: 100%|██████████| 52/52 [00:48<00:00,  1.06it/s, loss=0.604]


Epoch 17, Loss: 27.5371
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [18/100]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.522]


Epoch 18, Loss: 27.3955
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [19/100]: 100%|██████████| 52/52 [00:48<00:00,  1.07it/s, loss=0.44]


Epoch 19, Loss: 27.1006
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [20/100]: 100%|██████████| 52/52 [00:48<00:00,  1.07it/s, loss=0.487]


Epoch 20, Loss: 26.8072
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [21/100]: 100%|██████████| 52/52 [00:48<00:00,  1.07it/s, loss=0.548]


Epoch 21, Loss: 26.7822
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [22/100]: 100%|██████████| 52/52 [00:48<00:00,  1.07it/s, loss=0.527]


Epoch 22, Loss: 26.5536
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [23/100]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.542]


Epoch 23, Loss: 26.3718
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [24/100]: 100%|██████████| 52/52 [00:48<00:00,  1.07it/s, loss=0.585]


Epoch 24, Loss: 26.1793
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [25/100]: 100%|██████████| 52/52 [00:49<00:00,  1.06it/s, loss=0.442]


Epoch 25, Loss: 26.0489
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [26/100]: 100%|██████████| 52/52 [00:48<00:00,  1.07it/s, loss=0.485]


Epoch 26, Loss: 25.6275
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [27/100]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.495]


Epoch 27, Loss: 25.4438
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [28/100]: 100%|██████████| 52/52 [00:48<00:00,  1.07it/s, loss=0.564]


Epoch 28, Loss: 25.4268
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [29/100]: 100%|██████████| 52/52 [00:48<00:00,  1.07it/s, loss=0.532]


Epoch 29, Loss: 25.0255
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [30/100]: 100%|██████████| 52/52 [00:48<00:00,  1.07it/s, loss=0.705]


Epoch 30, Loss: 25.2419
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [31/100]: 100%|██████████| 52/52 [00:48<00:00,  1.07it/s, loss=0.59]


Epoch 31, Loss: 24.8559
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [32/100]: 100%|██████████| 52/52 [00:48<00:00,  1.07it/s, loss=0.497]


Epoch 32, Loss: 24.6131
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [33/100]: 100%|██████████| 52/52 [00:48<00:00,  1.07it/s, loss=0.593]


Epoch 33, Loss: 24.6128
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [34/100]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.508]


Epoch 34, Loss: 24.4807
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [35/100]: 100%|██████████| 52/52 [00:48<00:00,  1.07it/s, loss=0.528]


Epoch 35, Loss: 24.1983
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [36/100]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.418]


Epoch 36, Loss: 24.0801
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [37/100]: 100%|██████████| 52/52 [00:48<00:00,  1.07it/s, loss=0.57]


Epoch 37, Loss: 23.8361
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [38/100]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.476]


Epoch 38, Loss: 23.6599
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [39/100]: 100%|██████████| 52/52 [00:49<00:00,  1.05it/s, loss=0.415]


Epoch 39, Loss: 23.5542
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [40/100]: 100%|██████████| 52/52 [00:50<00:00,  1.04it/s, loss=0.679]


Epoch 40, Loss: 23.5156
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [41/100]: 100%|██████████| 52/52 [00:48<00:00,  1.07it/s, loss=0.477]


Epoch 41, Loss: 23.1874
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [42/100]: 100%|██████████| 52/52 [00:47<00:00,  1.08it/s, loss=0.382]


Epoch 42, Loss: 22.9863
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [43/100]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.412]


Epoch 43, Loss: 22.9649
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [44/100]: 100%|██████████| 52/52 [00:48<00:00,  1.07it/s, loss=0.481]


Epoch 44, Loss: 22.7997
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [45/100]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.36]


Epoch 45, Loss: 22.4552
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [46/100]: 100%|██████████| 52/52 [00:47<00:00,  1.09it/s, loss=0.426]


Epoch 46, Loss: 22.4269
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [47/100]: 100%|██████████| 52/52 [00:48<00:00,  1.07it/s, loss=0.329]


Epoch 47, Loss: 22.2950
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [48/100]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.327]


Epoch 48, Loss: 22.0186
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [49/100]: 100%|██████████| 52/52 [00:48<00:00,  1.07it/s, loss=0.444]


Epoch 49, Loss: 22.0646
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [50/100]: 100%|██████████| 52/52 [00:48<00:00,  1.07it/s, loss=0.394]


Epoch 50, Loss: 21.8056
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [51/100]: 100%|██████████| 52/52 [00:47<00:00,  1.10it/s, loss=0.351]


Epoch 51, Loss: 21.8302
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [52/100]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.474]


Epoch 52, Loss: 21.5969
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [53/100]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.315]


Epoch 53, Loss: 21.3955
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [54/100]: 100%|██████████| 52/52 [00:48<00:00,  1.07it/s, loss=0.486]


Epoch 54, Loss: 21.5880
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [55/100]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.439]


Epoch 55, Loss: 21.4413
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [56/100]: 100%|██████████| 52/52 [00:47<00:00,  1.08it/s, loss=0.299]


Epoch 56, Loss: 20.9948
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [57/100]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.465]


Epoch 57, Loss: 21.0328
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [58/100]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.396]


Epoch 58, Loss: 20.8429
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [59/100]: 100%|██████████| 52/52 [00:48<00:00,  1.07it/s, loss=0.436]


Epoch 59, Loss: 20.8651
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [60/100]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.557]


Epoch 60, Loss: 20.6693
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [61/100]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.327]


Epoch 61, Loss: 20.4908
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [62/100]: 100%|██████████| 52/52 [00:48<00:00,  1.07it/s, loss=0.33]


Epoch 62, Loss: 20.3193
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [63/100]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.392]


Epoch 63, Loss: 20.2109
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [64/100]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.332]


Epoch 64, Loss: 20.2668
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [65/100]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.507]


Epoch 65, Loss: 20.0205
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [66/100]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.364]


Epoch 66, Loss: 19.9518
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [67/100]: 100%|██████████| 52/52 [00:47<00:00,  1.09it/s, loss=0.402]


Epoch 67, Loss: 19.8675
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [68/100]: 100%|██████████| 52/52 [00:48<00:00,  1.07it/s, loss=0.407]


Epoch 68, Loss: 19.8801
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [69/100]: 100%|██████████| 52/52 [00:48<00:00,  1.07it/s, loss=0.365]


Epoch 69, Loss: 19.6029
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [70/100]: 100%|██████████| 52/52 [00:48<00:00,  1.07it/s, loss=0.378]


Epoch 70, Loss: 19.5340
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [71/100]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.374]


Epoch 71, Loss: 19.2853
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [72/100]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.407]


Epoch 72, Loss: 19.2260
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [73/100]: 100%|██████████| 52/52 [00:49<00:00,  1.05it/s, loss=0.295]


Epoch 73, Loss: 19.1569
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [74/100]: 100%|██████████| 52/52 [00:48<00:00,  1.07it/s, loss=0.405]


Epoch 74, Loss: 19.0991
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [75/100]: 100%|██████████| 52/52 [00:48<00:00,  1.07it/s, loss=0.37]


Epoch 75, Loss: 18.8321
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [76/100]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.264]


Epoch 76, Loss: 18.7773
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [77/100]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.412]


Epoch 77, Loss: 18.7159
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [78/100]: 100%|██████████| 52/52 [00:47<00:00,  1.09it/s, loss=0.381]


Epoch 78, Loss: 18.8418
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [79/100]: 100%|██████████| 52/52 [00:48<00:00,  1.07it/s, loss=0.479]


Epoch 79, Loss: 18.6382
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [80/100]: 100%|██████████| 52/52 [00:50<00:00,  1.03it/s, loss=0.419]


Epoch 80, Loss: 18.5735
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [81/100]: 100%|██████████| 52/52 [00:47<00:00,  1.09it/s, loss=0.374]


Epoch 81, Loss: 18.4254
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [82/100]: 100%|██████████| 52/52 [00:48<00:00,  1.07it/s, loss=0.432]


Epoch 82, Loss: 18.3790
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [83/100]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.438]


Epoch 83, Loss: 18.3704
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [84/100]: 100%|██████████| 52/52 [00:49<00:00,  1.04it/s, loss=0.406]


Epoch 84, Loss: 18.1112
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [85/100]: 100%|██████████| 52/52 [00:48<00:00,  1.07it/s, loss=0.376]


Epoch 85, Loss: 18.1749
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [86/100]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.367]


Epoch 86, Loss: 17.9752
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [87/100]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.321]


Epoch 87, Loss: 17.6461
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [88/100]: 100%|██████████| 52/52 [00:49<00:00,  1.04it/s, loss=0.265]


Epoch 88, Loss: 17.8309
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [89/100]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.323]


Epoch 89, Loss: 17.6284
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [90/100]: 100%|██████████| 52/52 [00:48<00:00,  1.06it/s, loss=0.304]


Epoch 90, Loss: 17.7272
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [91/100]: 100%|██████████| 52/52 [00:49<00:00,  1.06it/s, loss=0.384]


Epoch 91, Loss: 17.5709
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [92/100]: 100%|██████████| 52/52 [00:49<00:00,  1.06it/s, loss=0.349]


Epoch 92, Loss: 17.6300
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [93/100]: 100%|██████████| 52/52 [00:48<00:00,  1.07it/s, loss=0.332]


Epoch 93, Loss: 17.2215
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [94/100]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.324]


Epoch 94, Loss: 17.3765
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [95/100]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.373]


Epoch 95, Loss: 17.2139
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [96/100]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.327]


Epoch 96, Loss: 17.2919
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [97/100]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.396]


Epoch 97, Loss: 17.2519
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [98/100]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.344]


Epoch 98, Loss: 17.0536
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [99/100]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.334]


Epoch 99, Loss: 16.9223
Checkpoint saved at /content/EXP02-checkpoint.pth


Epoch [100/100]: 100%|██████████| 52/52 [00:48<00:00,  1.07it/s, loss=0.297]


Epoch 100, Loss: 16.8713
Checkpoint saved at /content/EXP02-checkpoint.pth


## Evaluation

In [40]:
import torch
from tqdm import tqdm
from nltk.translate.bleu_score import corpus_bleu

def apply_repetition_penalty(logits, input_ids, penalty=1.2):
    """
    惩罚已生成的 token，防止重复生成同一个词。
    """
    for i in range(logits.size(0)):
        for token_id in input_ids[i].tolist():
            logits[i, token_id] /= penalty
    return logits


def evaluate(model, test_loader, tokenizer, device, max_length=100, num_samples=5,
             temperature=1.0, top_k=50, repetition_penalty=1.2):
    model.eval()
    model.to(device)

    eos_token_id = tokenizer.eos_token_id or tokenizer.sep_token_id or tokenizer.pad_token_id or 0
    bos_token_id = tokenizer.bos_token_id or tokenizer.cls_token_id or 0

    results = []

    with torch.no_grad():
        for i, batch in enumerate(tqdm(test_loader)):
            images = batch['image'].to(device)
            gt_reports = tokenizer.batch_decode(batch['input_ids'], skip_special_tokens=True)

            # init: [B, 1]
            input_ids = torch.full((images.size(0), 1), bos_token_id, dtype=torch.long).to(device)

            for _ in range(max_length):
                outputs = model(images, input_ids)        # [B, T, vocab]
                logits = outputs[:, -1, :]
                logits = logits / temperature

                logits = apply_repetition_penalty(logits, input_ids, penalty=repetition_penalty)

                # Top-k sampling
                values, indices = torch.topk(logits, k=top_k, dim=-1)      # [B, K]
                probs = torch.softmax(values, dim=-1)                      # [B, K]
                sampled = torch.multinomial(probs, num_samples=1)         # [B, 1]
                next_token = indices.gather(-1, sampled)                  # [B, 1]

                input_ids = torch.cat([input_ids, next_token], dim=1)

                if eos_token_id and torch.eq(next_token, eos_token_id).all():
                    break

            predictions = tokenizer.batch_decode(input_ids, skip_special_tokens=True)

            for gt, pred in zip(gt_reports, predictions):
                results.append((gt.strip(), pred.strip()))

            if len(results) >= num_samples:
                break


    print("\n=== Sample Evaluation ===")
    for idx, (gt, pred) in enumerate(results[:num_samples]):
        print(f"\n🔹 Sample {idx+1}")
        print(f"🟢 Ground Truth:\n{gt}")
        print(f"🔵 Prediction:\n{pred}")

    # BLEU score
    pred_tokens = [pred.split() for _, pred in results]
    gt_tokens = [[gt.split()] for gt, _ in results]
    bleu_score = corpus_bleu(gt_tokens, pred_tokens)
    print(f"\n🎯 BLEU Score: {bleu_score:.4f}")

    return results, bleu_score


In [41]:

# evaluate(model, val_loader, tokenizer, device, num_samples=5)
evaluate(
    model,
    val_loader,
    tokenizer,
    device,
    num_samples=5,
    temperature=1.0,
    top_k=50,
    repetition_penalty=1.2
)

  0%|          | 0/13 [00:38<?, ?it/s]


=== Sample Evaluation ===

🔹 Sample 1
🟢 Ground Truth:
HEART: N/A MEDIASTINUM: N/A LUNGS: N/A PLEURA: N/A PNEUMOTHORAX: N/A FINDINGS: Mild cardiomegaly. Normal size and mediastinal contours. Clear lungs. No pneumothorax or pleural effusion. Unremarkable XXXX. IMPRESSION: Mild cardio
🔵 Prediction:
media to moderate images in size within normal limits. Vascular stable medial increased lung volumes with XX bron: N/A PNEUMOTHORAX: N/A FINDINGS: Stable appearance of the left base, scarring, and alvenous catheter tip overlieth atelectasis bilateral bibasilar airspace disease on the adjacent to prior study. No pneumothorax or large effusion. Left costoph

🔹 Sample 2
🟢 Ground Truth:
HEART: N/A MEDIASTINUM: N/A LUNGS: N/A PLEURA: N/A PNEUMOTHORAX: N/A FINDINGS: The cardiac and mediastinal contours are normal. The lungs are well-inflated and clear. There is no focal consolidation, pneumothorax or effusion. No acute bony abnormalities are seen. No radiopaque foreign bodies
🔵 Prediction:
superior.




([('HEART: N/A MEDIASTINUM: N/A LUNGS: N/A PLEURA: N/A PNEUMOTHORAX: N/A FINDINGS: Mild cardiomegaly. Normal size and mediastinal contours. Clear lungs. No pneumothorax or pleural effusion. Unremarkable XXXX. IMPRESSION: Mild cardio',
   'media to moderate images in size within normal limits. Vascular stable medial increased lung volumes with XX bron: N/A PNEUMOTHORAX: N/A FINDINGS: Stable appearance of the left base, scarring, and alvenous catheter tip overlieth atelectasis bilateral bibasilar airspace disease on the adjacent to prior study. No pneumothorax or large effusion. Left costoph'),
  ('HEART: N/A MEDIASTINUM: N/A LUNGS: N/A PLEURA: N/A PNEUMOTHORAX: N/A FINDINGS: The cardiac and mediastinal contours are normal. The lungs are well-inflated and clear. There is no focal consolidation, pneumothorax or effusion. No acute bony abnormalities are seen. No radiopaque foreign bodies',
   'superior. Lung volumes appear are clear bilaterally. Heart and mediastinum normal limits in conto

# EXP03

## Customize the dataset class

In [8]:
from PIL import Image
class IUXRayDataset(Dataset):
    def __init__(self, image_dir, report_dir, tokenizer, max_length=100):
        self.image_dir = image_dir
        self.report_dir = report_dir
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = self.load_samples()

        # process image
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor()
        ])

    def load_samples(self):
        print("===> Start loading samples from report and image...")
        samples = []
        count_img = 0
        count_match = 0
        count_missing_report = 0

        for img_name in tqdm(os.listdir(self.image_dir), desc="Parsing image-report pairs"):
            if not img_name.endswith('.png'):
                continue

            count_img += 1
            base_id = img_name.split('_')[0]  # CXR1005
            report_id = base_id.replace('CXR', '')  # 1005
            xml_path = os.path.join(self.report_dir, report_id + ".xml")

            if not os.path.exists(xml_path):
                count_missing_report += 1
                continue

            try:
                tree = ET.parse(xml_path)
                root = tree.getroot()
            except:
                continue

            # extract the FINDINGS 和 IMPRESSION
            findings, impression = "", ""
            for abstract in root.findall('.//AbstractText'):
                label = abstract.attrib.get('Label')
                if label == 'FINDINGS':
                    findings = abstract.text or ""
                if label == 'IMPRESSION':
                    impression = abstract.text or ""

          # structure the data
            fields = {
                "HEART": "N/A",
                "MEDIASTINUM": "N/A",
                "LUNGS": "N/A",
                "PLEURA": "N/A",
                "PNEUMOTHORAX": "N/A",
                "FINDINGS": findings,
                "IMPRESSION": impression
            }


            text = "\n".join([f"{k}: {v if v else 'N/A'}" for k, v in fields.items()])
            if not findings and not impression:
                continue  # empty report

            samples.append((img_name, text))
            count_match += 1

        print(f"✅ Total images scanned: {count_img}")
        print(f"✅ Total matched samples with report: {count_match}")
        print(f"⚠️  Total images missing report: {count_missing_report}")
        return samples




    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_name, text = self.samples[idx]
        img_path = os.path.join(self.image_dir, img_name)


        image = Image.open(img_path).convert("RGB")


        image = self.transform(image)

        tokens = self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )

        return {
            'image': image,  # Tensor
            'input_ids': tokens.input_ids.squeeze(0),
            'attention_mask': tokens.attention_mask.squeeze(0)
        }

## Tokenizer

In [9]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")  # 或 t5-base，或 gpt2


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/2.32k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

In [10]:
tokenizer.save_pretrained("./tokenizer")

('./tokenizer/tokenizer_config.json',
 './tokenizer/special_tokens_map.json',
 './tokenizer/spiece.model',
 './tokenizer/added_tokens.json',
 './tokenizer/tokenizer.json')

## Create Dataset and DataLoader

In [11]:
# 构建完整的 Dataset（所有数据）

full_dataset = IUXRayDataset(
    image_dir=image_dir,
    report_dir=report_dir,
    tokenizer=tokenizer,
    max_length=100  # 控制报告最长多少个 token
)


===> Start loading samples from report and image...


Parsing image-report pairs: 100%|██████████| 2065/2065 [01:20<00:00, 25.75it/s] 

✅ Total images scanned: 2063
✅ Total matched samples with report: 2054
⚠️  Total images missing report: 0





In [12]:

sample = full_dataset[0]

# image tensor shape
print("Image shape:", sample["image"].shape)

# Token IDs
print(" Token IDs:", sample["input_ids"].tolist())

# Decode  to original text
decoded_text = tokenizer.decode(sample["input_ids"], skip_special_tokens=True)
print(" Decoded Text:\n", decoded_text)



Image shape: torch.Size([3, 224, 224])
 Token IDs: [3, 6021, 8241, 10, 445, 87, 188, 3, 30296, 134, 25424, 6122, 10, 445, 87, 188, 301, 25158, 134, 10, 445, 87, 188, 3, 5329, 26296, 188, 10, 445, 87, 188, 276, 4171, 6122, 6951, 6299, 4763, 4, 10, 445, 87, 188, 377, 13885, 2365, 134, 10, 37, 16216, 8172, 7, 17, 10270, 19561, 19, 441, 1389, 6790, 21, 3179, 5, 465, 15949, 844, 13, 3, 26836, 16690, 5, 71, 3, 10379, 3676, 3, 7662, 83, 32, 51, 9, 19, 4313, 16, 8, 22586, 2663, 13, 8, 646, 1364, 3, 11846, 15, 5, 3104, 75, 3676, 25049, 3, 4, 1]
 Decoded Text:
 HEART: N/A MEDIASTINUM: N/A LUNGS: N/A PLEURA: N/A PNEUMOTHORAX: N/A FINDINGS: The cardiomediastinal silhouette is within normal limits for appearance. No focal areas of pulmonary consolidation. A calcified granuloma is identified in the peripheral aspect of the left lower lobe. Calcified lymph X


##Summary:
every image has a report, but not every report can find a matching image

# Training DataLoader & validation dataloader

In [13]:

import torch

train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size

train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [train_size, val_size])

print(f"Training dataset: {len(train_dataset)}")
print(f"Testing Dataset: {len(val_dataset)}")

# Dataloader
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

Training dataset: 1643
Testing Dataset: 411


## Encoder

In [14]:
from transformers import CLIPModel, CLIPProcessor


class EncoderCLIP(nn.Module):
    def __init__(self, embed_size, device):
        super(EncoderCLIP, self).__init__()
        self.device = device
        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

        self.fc = nn.Linear(512, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)

    def forward(self, images):
      inputs = self.processor(images=images, return_tensors="pt").to(self.device)


      image_features = self.clip_model.get_image_features(**inputs)  # [B, 512]
      projected = self.fc(image_features).unsqueeze(1)               # [B, 1, embed]

      return projected



## Decoder

In [15]:
import torch
import torch.nn as nn

class DecoderTransformer(nn.Module):
    def __init__(self, embed_size, vocab_size, num_layers=4, heads=4, ff_hidden=1024, dropout=0.3, max_len=100):
        super(DecoderTransformer, self).__init__()

        self.embed = nn.Embedding(vocab_size, embed_size)
        self.positional_encoding = PositionalEncoding(embed_size, dropout, max_len)

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=embed_size,
            nhead=heads,
            dim_feedforward=ff_hidden,
            dropout=dropout,
            batch_first=True
        )

        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.fc = nn.Linear(embed_size, vocab_size)

    def forward(self, image_features, captions):
        embeddings = self.embed(captions)
        embeddings = self.positional_encoding(embeddings)

        tgt_mask = self.generate_square_subsequent_mask(captions.size(1)).to(captions.device)

        output = self.transformer_decoder(
            tgt=embeddings,
            memory=image_features,
            tgt_mask=tgt_mask
        )

        return self.fc(output)


    def generate_square_subsequent_mask(self, sz):
        return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)


class PositionalEncoding(nn.Module):
    def __init__(self, embed_size, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, embed_size)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_size, 2).float() * (-torch.log(torch.tensor(10000.0)) / embed_size))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)


In [16]:
class ImageCaptionModel(nn.Module):
    def __init__(self, embed_size, vocab_size, device):
        super(ImageCaptionModel, self).__init__()
        self.encoder = EncoderCLIP(embed_size, device)
        self.decoder = DecoderTransformer(embed_size, vocab_size)

    def forward(self, images, captions):
        image_features = self.encoder(images)  # [batch, embed]
        outputs = self.decoder(image_features, captions)
        return outputs


## Initialize the model

In [17]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ImageCaptionModel(embed_size=512, vocab_size=tokenizer.vocab_size, device=device).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

config.json:   0%|          | 0.00/4.19k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/605M [00:00<?, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


model.safetensors:   0%|          | 0.00/605M [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/316 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/592 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/862k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.22M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/389 [00:00<?, ?B/s]

## Train In loop

In [18]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
import os

def train(model, train_loader, optimizer, criterion, tokenizer, device, num_epochs=10, save_path="checkpoint.pth"):

    model.to(device)
    start_epoch = 0


    if os.path.exists(save_path):
        print(f"==> Found checkpoint at {save_path}, loading...")
        checkpoint = torch.load(save_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        print(f"==> Resuming training from epoch {start_epoch}")
    model.train()

    for epoch in range(start_epoch,num_epochs):
        total_loss = 0
        loop = tqdm(train_loader, leave=True, dynamic_ncols=True)
        for batch in loop:
            images = batch['image'].to(device)
            captions = batch['input_ids'].to(device)

            optimizer.zero_grad()
            outputs = model(images, captions[:, :-1])
            loss = criterion(outputs.reshape(-1, outputs.size(-1)), captions[:, 1:].reshape(-1))
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            loop.set_description(f"Epoch [{epoch+1}/{num_epochs}]")
            loop.set_postfix(loss=loss.item())

        print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")


        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': total_loss
        }


        os.makedirs(os.path.dirname(save_path), exist_ok=True)

        torch.save(checkpoint, save_path)
        print(f"Checkpoint saved at {save_path}")


In [19]:
train(model, train_loader, optimizer, criterion, tokenizer, device, num_epochs=120,save_path="/content/drive/MyDrive/checkpoint.pth")

  0%|          | 0/52 [00:00<?, ?it/s]It looks like you are trying to rescale already rescaled images. If the input images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again.
Epoch [1/120]: 100%|██████████| 52/52 [21:45<00:00, 25.11s/it, loss=3.76]


Epoch 1, Loss: 329.4586
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [2/120]: 100%|██████████| 52/52 [00:51<00:00,  1.01it/s, loss=2.16]


Epoch 2, Loss: 142.5535
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [3/120]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=1.96]


Epoch 3, Loss: 100.5692
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [4/120]: 100%|██████████| 52/52 [00:49<00:00,  1.04it/s, loss=1.44]


Epoch 4, Loss: 84.1896
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [5/120]: 100%|██████████| 52/52 [00:47<00:00,  1.09it/s, loss=1.44]


Epoch 5, Loss: 75.1815
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [6/120]: 100%|██████████| 52/52 [00:49<00:00,  1.04it/s, loss=1.25]


Epoch 6, Loss: 68.9887
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [7/120]: 100%|██████████| 52/52 [00:47<00:00,  1.09it/s, loss=1.11]


Epoch 7, Loss: 64.3781
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [8/120]: 100%|██████████| 52/52 [00:50<00:00,  1.04it/s, loss=1.1]


Epoch 8, Loss: 60.7209
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [9/120]: 100%|██████████| 52/52 [00:47<00:00,  1.08it/s, loss=1.09]


Epoch 9, Loss: 57.5075
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [10/120]: 100%|██████████| 52/52 [00:50<00:00,  1.04it/s, loss=0.994]


Epoch 10, Loss: 54.6582
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [11/120]: 100%|██████████| 52/52 [01:23<00:00,  1.60s/it, loss=0.977]


Epoch 11, Loss: 52.1510
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [12/120]: 100%|██████████| 52/52 [01:23<00:00,  1.60s/it, loss=1.08]


Epoch 12, Loss: 50.0933
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [13/120]: 100%|██████████| 52/52 [01:25<00:00,  1.65s/it, loss=0.87]


Epoch 13, Loss: 48.1421
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [14/120]: 100%|██████████| 52/52 [01:23<00:00,  1.61s/it, loss=0.97]


Epoch 14, Loss: 46.4540
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [15/120]: 100%|██████████| 52/52 [01:23<00:00,  1.61s/it, loss=1.14]


Epoch 15, Loss: 45.0831
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [16/120]: 100%|██████████| 52/52 [01:23<00:00,  1.61s/it, loss=0.728]


Epoch 16, Loss: 43.4716
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [17/120]: 100%|██████████| 52/52 [01:24<00:00,  1.62s/it, loss=0.675]


Epoch 17, Loss: 42.1490
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [18/120]: 100%|██████████| 52/52 [01:24<00:00,  1.62s/it, loss=0.77]


Epoch 18, Loss: 41.0362
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [19/120]: 100%|██████████| 52/52 [01:23<00:00,  1.61s/it, loss=0.685]


Epoch 19, Loss: 39.9309
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [20/120]: 100%|██████████| 52/52 [01:23<00:00,  1.61s/it, loss=0.885]


Epoch 20, Loss: 39.1448
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [21/120]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.531]


Epoch 21, Loss: 37.9319
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [22/120]: 100%|██████████| 52/52 [01:23<00:00,  1.60s/it, loss=0.704]


Epoch 22, Loss: 37.1845
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [23/120]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.748]


Epoch 23, Loss: 36.4528
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [24/120]: 100%|██████████| 52/52 [00:47<00:00,  1.08it/s, loss=0.68]


Epoch 24, Loss: 35.6527
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [25/120]: 100%|██████████| 52/52 [01:22<00:00,  1.58s/it, loss=0.532]


Epoch 25, Loss: 34.7970
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [26/120]: 100%|██████████| 52/52 [00:47<00:00,  1.10it/s, loss=0.639]


Epoch 26, Loss: 34.2134
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [27/120]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.545]


Epoch 27, Loss: 33.5713
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [28/120]: 100%|██████████| 52/52 [01:23<00:00,  1.60s/it, loss=0.635]


Epoch 28, Loss: 32.9514
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [29/120]: 100%|██████████| 52/52 [00:47<00:00,  1.08it/s, loss=0.812]


Epoch 29, Loss: 32.4743
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [30/120]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.602]


Epoch 30, Loss: 31.8424
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [31/120]: 100%|██████████| 52/52 [01:22<00:00,  1.59s/it, loss=0.361]


Epoch 31, Loss: 31.1621
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [32/120]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.664]


Epoch 32, Loss: 30.8016
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [33/120]: 100%|██████████| 52/52 [00:48<00:00,  1.07it/s, loss=0.776]


Epoch 33, Loss: 30.3704
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [34/120]: 100%|██████████| 52/52 [00:47<00:00,  1.09it/s, loss=0.543]


Epoch 34, Loss: 29.8130
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [35/120]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.624]


Epoch 35, Loss: 29.4275
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [36/120]: 100%|██████████| 52/52 [01:25<00:00,  1.64s/it, loss=0.57]


Epoch 36, Loss: 28.9494
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [37/120]: 100%|██████████| 52/52 [00:48<00:00,  1.07it/s, loss=0.505]


Epoch 37, Loss: 28.4728
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [38/120]: 100%|██████████| 52/52 [01:22<00:00,  1.58s/it, loss=0.418]


Epoch 38, Loss: 28.0153
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [39/120]: 100%|██████████| 52/52 [00:48<00:00,  1.07it/s, loss=0.494]


Epoch 39, Loss: 27.7027
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [40/120]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.479]


Epoch 40, Loss: 27.2051
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [41/120]: 100%|██████████| 52/52 [00:47<00:00,  1.10it/s, loss=0.64]


Epoch 41, Loss: 27.0912
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [42/120]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.328]


Epoch 42, Loss: 26.3943
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [43/120]: 100%|██████████| 52/52 [01:24<00:00,  1.62s/it, loss=0.561]


Epoch 43, Loss: 26.2166
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [44/120]: 100%|██████████| 52/52 [00:47<00:00,  1.09it/s, loss=0.556]


Epoch 44, Loss: 25.9493
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [45/120]: 100%|██████████| 52/52 [00:48<00:00,  1.07it/s, loss=0.509]


Epoch 45, Loss: 25.5302
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [46/120]: 100%|██████████| 52/52 [01:23<00:00,  1.61s/it, loss=0.603]


Epoch 46, Loss: 25.2595
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [47/120]: 100%|██████████| 52/52 [00:48<00:00,  1.07it/s, loss=0.558]


Epoch 47, Loss: 24.9191
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [48/120]: 100%|██████████| 52/52 [00:47<00:00,  1.09it/s, loss=0.482]


Epoch 48, Loss: 24.5907
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [49/120]: 100%|██████████| 52/52 [01:23<00:00,  1.60s/it, loss=0.436]


Epoch 49, Loss: 24.2566
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [50/120]: 100%|██████████| 52/52 [00:52<00:00,  1.01s/it, loss=0.466]


Epoch 50, Loss: 23.8798
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [51/120]: 100%|██████████| 52/52 [01:24<00:00,  1.62s/it, loss=0.483]


Epoch 51, Loss: 23.6940
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [52/120]: 100%|██████████| 52/52 [01:24<00:00,  1.62s/it, loss=0.427]


Epoch 52, Loss: 23.3860
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [53/120]: 100%|██████████| 52/52 [01:23<00:00,  1.60s/it, loss=0.405]


Epoch 53, Loss: 23.0196
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [54/120]: 100%|██████████| 52/52 [01:24<00:00,  1.63s/it, loss=0.468]


Epoch 54, Loss: 22.8840
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [55/120]: 100%|██████████| 52/52 [01:24<00:00,  1.62s/it, loss=0.467]


Epoch 55, Loss: 22.5902
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [56/120]: 100%|██████████| 52/52 [01:22<00:00,  1.59s/it, loss=0.418]


Epoch 56, Loss: 22.2945
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [57/120]: 100%|██████████| 52/52 [01:25<00:00,  1.64s/it, loss=0.376]


Epoch 57, Loss: 21.9984
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [58/120]: 100%|██████████| 52/52 [01:23<00:00,  1.60s/it, loss=0.328]


Epoch 58, Loss: 21.7057
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [59/120]: 100%|██████████| 52/52 [01:23<00:00,  1.60s/it, loss=0.478]


Epoch 59, Loss: 21.5959
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [60/120]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.479]


Epoch 60, Loss: 21.3477
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [61/120]: 100%|██████████| 52/52 [01:23<00:00,  1.61s/it, loss=0.525]


Epoch 61, Loss: 21.1203
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [62/120]: 100%|██████████| 52/52 [00:48<00:00,  1.07it/s, loss=0.421]


Epoch 62, Loss: 20.8614
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [63/120]: 100%|██████████| 52/52 [01:23<00:00,  1.60s/it, loss=0.33]


Epoch 63, Loss: 20.5747
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [64/120]: 100%|██████████| 52/52 [00:47<00:00,  1.08it/s, loss=0.396]


Epoch 64, Loss: 20.3292
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [65/120]: 100%|██████████| 52/52 [01:23<00:00,  1.60s/it, loss=0.366]


Epoch 65, Loss: 20.1749
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [66/120]: 100%|██████████| 52/52 [00:47<00:00,  1.09it/s, loss=0.416]


Epoch 66, Loss: 19.9581
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [67/120]: 100%|██████████| 52/52 [00:47<00:00,  1.09it/s, loss=0.413]


Epoch 67, Loss: 19.7718
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [68/120]: 100%|██████████| 52/52 [01:21<00:00,  1.58s/it, loss=0.373]


Epoch 68, Loss: 19.4846
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [69/120]: 100%|██████████| 52/52 [00:47<00:00,  1.09it/s, loss=0.443]


Epoch 69, Loss: 19.3690
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [70/120]: 100%|██████████| 52/52 [00:47<00:00,  1.09it/s, loss=0.369]


Epoch 70, Loss: 19.0441
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [71/120]: 100%|██████████| 52/52 [01:23<00:00,  1.60s/it, loss=0.303]


Epoch 71, Loss: 18.8716
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [72/120]: 100%|██████████| 52/52 [00:48<00:00,  1.07it/s, loss=0.367]


Epoch 72, Loss: 18.7084
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [73/120]: 100%|██████████| 52/52 [00:47<00:00,  1.09it/s, loss=0.34]


Epoch 73, Loss: 18.5144
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [74/120]: 100%|██████████| 52/52 [00:47<00:00,  1.09it/s, loss=0.439]


Epoch 74, Loss: 18.3955
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [75/120]: 100%|██████████| 52/52 [00:47<00:00,  1.09it/s, loss=0.267]


Epoch 75, Loss: 18.0072
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [76/120]: 100%|██████████| 52/52 [01:23<00:00,  1.60s/it, loss=0.38]


Epoch 76, Loss: 17.9407
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [77/120]: 100%|██████████| 52/52 [00:48<00:00,  1.07it/s, loss=0.362]


Epoch 77, Loss: 17.7889
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [78/120]: 100%|██████████| 52/52 [00:47<00:00,  1.09it/s, loss=0.359]


Epoch 78, Loss: 17.5875
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [79/120]: 100%|██████████| 52/52 [00:48<00:00,  1.07it/s, loss=0.244]


Epoch 79, Loss: 17.3541
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [80/120]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.318]


Epoch 80, Loss: 17.2426
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [81/120]: 100%|██████████| 52/52 [01:23<00:00,  1.61s/it, loss=0.263]


Epoch 81, Loss: 17.0522
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [82/120]: 100%|██████████| 52/52 [00:48<00:00,  1.07it/s, loss=0.407]


Epoch 82, Loss: 16.9695
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [83/120]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.35]


Epoch 83, Loss: 16.7532
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [84/120]: 100%|██████████| 52/52 [01:24<00:00,  1.63s/it, loss=0.224]


Epoch 84, Loss: 16.4595
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [85/120]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.347]


Epoch 85, Loss: 16.3812
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [86/120]: 100%|██████████| 52/52 [00:48<00:00,  1.06it/s, loss=0.317]


Epoch 86, Loss: 16.2265
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [87/120]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.341]


Epoch 87, Loss: 16.0342
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [88/120]: 100%|██████████| 52/52 [00:47<00:00,  1.10it/s, loss=0.26]


Epoch 88, Loss: 15.9499
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [89/120]: 100%|██████████| 52/52 [01:24<00:00,  1.62s/it, loss=0.268]


Epoch 89, Loss: 15.8390
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [90/120]: 100%|██████████| 52/52 [00:50<00:00,  1.02it/s, loss=0.293]


Epoch 90, Loss: 15.7116
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [91/120]: 100%|██████████| 52/52 [01:22<00:00,  1.59s/it, loss=0.266]


Epoch 91, Loss: 15.4529
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [92/120]: 100%|██████████| 52/52 [01:25<00:00,  1.65s/it, loss=0.287]


Epoch 92, Loss: 15.2565
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [93/120]: 100%|██████████| 52/52 [01:23<00:00,  1.60s/it, loss=0.271]


Epoch 93, Loss: 15.1797
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [94/120]: 100%|██████████| 52/52 [01:23<00:00,  1.60s/it, loss=0.284]


Epoch 94, Loss: 15.1600
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [95/120]: 100%|██████████| 52/52 [01:22<00:00,  1.59s/it, loss=0.304]


Epoch 95, Loss: 14.9766
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [96/120]: 100%|██████████| 52/52 [01:24<00:00,  1.62s/it, loss=0.231]


Epoch 96, Loss: 14.8138
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [97/120]: 100%|██████████| 52/52 [01:26<00:00,  1.66s/it, loss=0.299]


Epoch 97, Loss: 14.6624
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [98/120]: 100%|██████████| 52/52 [01:23<00:00,  1.61s/it, loss=0.24]


Epoch 98, Loss: 14.5187
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [99/120]: 100%|██████████| 52/52 [01:23<00:00,  1.60s/it, loss=0.227]


Epoch 99, Loss: 14.3477
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [100/120]: 100%|██████████| 52/52 [01:21<00:00,  1.57s/it, loss=0.226]


Epoch 100, Loss: 14.2460
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [101/120]: 100%|██████████| 52/52 [00:47<00:00,  1.08it/s, loss=0.359]


Epoch 101, Loss: 14.2200
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [102/120]: 100%|██████████| 52/52 [01:24<00:00,  1.63s/it, loss=0.267]


Epoch 102, Loss: 14.0457
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [103/120]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.223]


Epoch 103, Loss: 13.8985
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [104/120]: 100%|██████████| 52/52 [00:48<00:00,  1.06it/s, loss=0.298]


Epoch 104, Loss: 13.8162
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [105/120]: 100%|██████████| 52/52 [01:25<00:00,  1.64s/it, loss=0.237]


Epoch 105, Loss: 13.7402
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [106/120]: 100%|██████████| 52/52 [00:48<00:00,  1.06it/s, loss=0.273]


Epoch 106, Loss: 13.5621
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [107/120]: 100%|██████████| 52/52 [00:49<00:00,  1.06it/s, loss=0.242]


Epoch 107, Loss: 13.4110
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [108/120]: 100%|██████████| 52/52 [01:25<00:00,  1.65s/it, loss=0.207]


Epoch 108, Loss: 13.2726
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [109/120]: 100%|██████████| 52/52 [00:49<00:00,  1.05it/s, loss=0.316]


Epoch 109, Loss: 13.2642
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [110/120]: 100%|██████████| 52/52 [01:23<00:00,  1.60s/it, loss=0.198]


Epoch 110, Loss: 12.9999
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [111/120]: 100%|██████████| 52/52 [00:47<00:00,  1.09it/s, loss=0.238]


Epoch 111, Loss: 12.9616
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [112/120]: 100%|██████████| 52/52 [00:49<00:00,  1.06it/s, loss=0.301]


Epoch 112, Loss: 13.0366
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [113/120]: 100%|██████████| 52/52 [01:25<00:00,  1.64s/it, loss=0.229]


Epoch 113, Loss: 12.7065
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [114/120]: 100%|██████████| 52/52 [00:49<00:00,  1.05it/s, loss=0.28]


Epoch 114, Loss: 12.7336
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [115/120]: 100%|██████████| 52/52 [01:23<00:00,  1.61s/it, loss=0.219]


Epoch 115, Loss: 12.6131
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [116/120]: 100%|██████████| 52/52 [00:48<00:00,  1.07it/s, loss=0.223]


Epoch 116, Loss: 12.4917
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [117/120]: 100%|██████████| 52/52 [00:47<00:00,  1.10it/s, loss=0.222]


Epoch 117, Loss: 12.4584
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [118/120]: 100%|██████████| 52/52 [00:47<00:00,  1.09it/s, loss=0.309]


Epoch 118, Loss: 12.4456
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [119/120]: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s, loss=0.23]


Epoch 119, Loss: 12.1998
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


Epoch [120/120]: 100%|██████████| 52/52 [01:24<00:00,  1.63s/it, loss=0.243]


Epoch 120, Loss: 12.1756
Checkpoint saved at /content/drive/MyDrive/checkpoint.pth


In [20]:

sample = train_dataset[0]

print("🖼️ Image shape:", sample["image"].shape)

# Token IDs
print("🧾 Token IDs:", sample["input_ids"].tolist())

decoded_text = tokenizer.decode(sample["input_ids"], skip_special_tokens=True)
print("📄 Decoded Text:\n", decoded_text)



🖼️ Image shape: torch.Size([3, 224, 224])
🧾 Token IDs: [3, 6021, 8241, 10, 445, 87, 188, 3, 30296, 134, 25424, 6122, 10, 445, 87, 188, 301, 25158, 134, 10, 445, 87, 188, 3, 5329, 26296, 188, 10, 445, 87, 188, 276, 4171, 6122, 6951, 6299, 4763, 4, 10, 445, 87, 188, 377, 13885, 2365, 134, 10, 37, 3, 4, 4, 4, 4, 6498, 3, 6848, 13, 851, 138, 11, 3, 12088, 2252, 9413, 7, 13, 8, 5738, 5, 26241, 3393, 3433, 3, 4, 4, 4, 4, 8, 3, 17, 21783, 226, 5, 37, 16216, 8172, 7, 17, 10270, 17643, 7, 33, 441, 1389, 6790, 5, 10035, 2157, 1208, 1]
📄 Decoded Text:
 HEART: N/A MEDIASTINUM: N/A LUNGS: N/A PLEURA: N/A PNEUMOTHORAX: N/A FINDINGS: The XXXX examination consists of frontal and lateral radiographs of the chest. External monitor leads XXXX the thorax. The cardiomediastinal contours are within normal limits. Pulmonary


## Evaluate

In [21]:
import torch
from tqdm import tqdm
from nltk.translate.bleu_score import corpus_bleu

def apply_repetition_penalty(logits, input_ids, penalty=1.2):

    for i in range(logits.size(0)):
        seen = set(input_ids[i].tolist())
        for token_id in seen:
            if logits[i, token_id] < 0:
                logits[i, token_id] *= penalty
            else:
                logits[i, token_id] /= penalty
    return logits

def evaluate(model, test_loader, tokenizer, device, max_length=100, num_samples=5,
             temperature=1.0, top_k=50, repetition_penalty=1.2):

    model.eval()
    model.to(device)

    eos_token_id = tokenizer.eos_token_id or tokenizer.sep_token_id or tokenizer.pad_token_id or 0
    bos_token_id = tokenizer.bos_token_id or tokenizer.cls_token_id or 0

    results = []

    with torch.no_grad():
        for batch in tqdm(test_loader):
            images = batch['image'].to(device)
            gt_reports = tokenizer.batch_decode(batch['input_ids'], skip_special_tokens=True)


            input_ids = torch.full((images.size(0), 1), bos_token_id, dtype=torch.long).to(device)

            finished = torch.zeros(images.size(0), dtype=torch.bool).to(device)

            for _ in range(max_length):
                outputs = model(images, input_ids)         # [B, T, vocab]
                logits = outputs[:, -1, :]
                logits = logits / temperature
                logits = apply_repetition_penalty(logits, input_ids, penalty=repetition_penalty)

                # Top-k sampling
                values, indices = torch.topk(logits, k=top_k, dim=-1)       # [B, K]
                probs = torch.softmax(values, dim=-1)                       # [B, K]
                sampled = torch.multinomial(probs, num_samples=1)          # [B, 1]
                next_token = indices.gather(-1, sampled)                   # [B, 1]

                input_ids = torch.cat([input_ids, next_token], dim=1)

                finished |= (next_token.squeeze() == eos_token_id)
                if finished.all():
                    break

            predictions = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
            predictions = [p.strip() if p.strip() else "<EMPTY>" for p in predictions]
            gt_reports = [gt.strip() if gt.strip() else "<EMPTY>" for gt in gt_reports]

            for gt, pred in zip(gt_reports, predictions):
                results.append((gt, pred))


            if len(results) >= num_samples:
                results = results[:num_samples]
                break


    print("\n=== Sample Evaluation ===")
    for idx, (gt, pred) in enumerate(results):
        print(f"\n🔹 Sample {idx+1}")
        print(f"🟢 Ground Truth:\n{gt}")
        print(f"🔵 Prediction:\n{pred}")


    pred_tokens = [pred.split() for _, pred in results]
    gt_tokens = [[gt.split()] for gt, _ in results]
    bleu_score = corpus_bleu(gt_tokens, pred_tokens)
    print(f"\n🎯 BLEU Score: {bleu_score:.4f}")

    return results, bleu_score


In [22]:

# evaluate(model, val_loader, tokenizer, device, num_samples=5)
evaluate(
    model,
    val_loader,
    tokenizer,
    device,
    num_samples=5,
    temperature=1.0,
    top_k=50,
    repetition_penalty=1.2
)

  0%|          | 0/13 [01:03<?, ?it/s]


=== Sample Evaluation ===

🔹 Sample 1
🟢 Ground Truth:
HEART: N/A MEDIASTINUM: N/A LUNGS: N/A PLEURA: N/A PNEUMOTHORAX: N/A FINDINGS: The heart is normal in size and contour. There is no mediastinal widening. The lungs are clear bilaterally. No large pleural effusion or pneumothorax. The XXXX are intact. IM
🔵 Prediction:
ART: N/A MEDIASTINUM: N/A LUNGS: N/A PLEURA: N/A PNEUMOTHORAX: N/A FINDINGS: There are low lung volumes. The lungs are otherwise clear without focal airspace consolidation. No pleural effusion or pneumothorax. Normal heart size and mediastinal contour. There is degenerative changes of the spine. IMPRESS

🔹 Sample 2
🟢 Ground Truth:
HEART: N/A MEDIASTINUM: N/A LUNGS: N/A PLEURA: N/A PNEUMOTHORAX: N/A FINDINGS: The heart size and pulmonary vascularity appear within normal limits. The lungs are free of focal airspace disease. No pleural effusion or pneumothorax is seen. Calcified granuloma is identified.
🔵 Prediction:
ART: N/A MEDIASTINUM: N/A LUNGS: N/A PLEURA: N/A PNEUMO




([('HEART: N/A MEDIASTINUM: N/A LUNGS: N/A PLEURA: N/A PNEUMOTHORAX: N/A FINDINGS: The heart is normal in size and contour. There is no mediastinal widening. The lungs are clear bilaterally. No large pleural effusion or pneumothorax. The XXXX are intact. IM',
   'ART: N/A MEDIASTINUM: N/A LUNGS: N/A PLEURA: N/A PNEUMOTHORAX: N/A FINDINGS: There are low lung volumes. The lungs are otherwise clear without focal airspace consolidation. No pleural effusion or pneumothorax. Normal heart size and mediastinal contour. There is degenerative changes of the spine. IMPRESS'),
  ('HEART: N/A MEDIASTINUM: N/A LUNGS: N/A PLEURA: N/A PNEUMOTHORAX: N/A FINDINGS: The heart size and pulmonary vascularity appear within normal limits. The lungs are free of focal airspace disease. No pleural effusion or pneumothorax is seen. Calcified granuloma is identified.',
   'ART: N/A MEDIASTINUM: N/A LUNGS: N/A PLEURA: N/A PNEUMOTHORAX: N/A FINDINGS: Frontal and lateral views of the chest with overlying external car