In [1]:
import os, pickle, random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm

# --- Import your model components and tokenizer ---
from timm import create_model
from transformers import GPT2TokenizerFast, GPT2LMHeadModel, GPT2Config

# --- Use Albumentations for image augmentation (as in your reference code) ---
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image

# --- Set up the tokenizer ---
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
# GPT-2 does not have a pad token so we use EOS
tokenizer.pad_token = tokenizer.eos_token

# --- Load your data (pkl file) ---
with open('updated_merge_json_200x300.pkl', 'rb') as f:
    data = pickle.load(f)

# --- Balance the dataset ---
task_classes = ['ra', 'normal']
class_counts = {cls: 0 for cls in task_classes}
data_by_class = {cls: [] for cls in task_classes}

for key, entry in data.items():
    # use lower-case class label and only include if image file exists
    class_label = entry.get('class', '').lower()
    if class_label in class_counts and os.path.exists(entry['file_path']):
        class_counts[class_label] += 1
        data_by_class[class_label].append(entry)

min_class_count = min(class_counts.values())
balanced_data = []
for cls in task_classes:
    balanced_data.extend(random.sample(data_by_class[cls], min_class_count))

# Optionally, you may multiply the balanced_data for augmentation purposes:
augmented_data = balanced_data * 2

# --- Define Albumentations transforms ---
train_tfms = A.Compose([
    A.Resize(224, 224),
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.5),
    A.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5], always_apply=True),
    ToTensorV2()
])

valid_tfms = A.Compose([
    A.Resize(224, 224),
    A.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5], always_apply=True),
    ToTensorV2()
])

# For patch images we use a similar transform (you can adjust if needed)
patch_tfms = A.Compose([
    A.Resize(112, 112),
    A.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5], always_apply=True),
    ToTensorV2()
])

# --- Define a new Dataset for report generation ---
class ReportDataset(Dataset):
    def __init__(self, data, img_tfms, patch_tfms, tokenizer, max_length=128):
        self.data = data
        self.img_tfms = img_tfms
        self.patch_tfms = patch_tfms
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        entry = self.data[idx]
        
        # Load the main image
        image = Image.open(entry['file_path']).convert('RGB')
        image = np.array(image)
        augmented = self.img_tfms(image=image)
        image_tensor = augmented['image']
        
        # Process patch images (if any)
        patches = entry.get('bbx', [])
        if len(patches) > 0:
            # Limit to up to 34 patches (as before)
            patch_imgs = [Image.fromarray(patch) for patch in patches[:34]]
            patch_tensors = []
            for p in patch_imgs:
                p = np.array(p)
                aug_patch = self.patch_tfms(image=p)
                patch_tensors.append(aug_patch['image'])
            # Concatenate along the channel dimension: resulting shape (34*3, 112, 112)
            combined_patches = torch.cat(patch_tensors, dim=0)
        else:
            combined_patches = torch.zeros(34*3, 112, 112)
        
        # Process the report text (diagnosis field)
        # Replace unwanted tokens and append EOS token (using tokenizer.eos_token)
        report_text = entry['diagnosis'].replace('_x000D_', ' ').strip()
        caption = f"{report_text}{self.tokenizer.eos_token}"
        
        # Tokenize without padding here (we will pad in the collate function)
        encoding = self.tokenizer(
            caption,
            truncation=True,
            max_length=self.max_length,
            return_attention_mask=True,
            return_tensors="pt"
        )
        # Squeeze to remove batch dimension (now shape: [seq_len])
        input_ids = encoding['input_ids'].squeeze(0)
        attention_mask = encoding['attention_mask'].squeeze(0)
        
        # Create shifted labels for LM training (shift left by one)
        # For example: if input_ids = [a, b, c, eos] then labels = [b, c, eos, eos]
        labels = input_ids.clone()
        if input_ids.size(0) > 1:
            labels[:-1] = input_ids[1:]
        # (Optionally, you can set the last token’s label to -100 so that loss is not computed)
        # labels[-1] = -100
        
        return image_tensor, combined_patches, input_ids, attention_mask, labels

# --- Define a collate function to pad variable-length sequences ---
def collate_fn(batch):
    images, patches, input_ids_list, attn_masks, labels_list = zip(*batch)
    images = torch.stack(images, dim=0)
    patches = torch.stack(patches, dim=0)
    
    # Use the tokenizer's pad method to pad the input_ids and labels
    batch_encoding = tokenizer.pad(
        {'input_ids': list(input_ids_list)},
        padding='longest',
        return_tensors='pt'
    )
    input_ids = batch_encoding['input_ids']
    
    batch_encoding_labels = tokenizer.pad(
        {'input_ids': list(labels_list)},
        padding='longest',
        return_tensors='pt'
    )
    labels = batch_encoding_labels['input_ids']
    
    # Pad attention masks similarly
    batch_attn = tokenizer.pad(
        {'input_ids': list(attn_masks)},
        padding='longest',
        return_tensors='pt'
    )
    attention_mask = batch_attn['input_ids']
    
    # Replace pad token positions in labels with -100 so that loss is not computed on them
    labels[ input_ids == tokenizer.pad_token_id ] = -100
    
    return images, patches, input_ids, attention_mask, labels

# --- Create Dataset and DataLoaders ---
full_dataset = ReportDataset(augmented_data, train_tfms, patch_tfms, tokenizer, max_length=128)

# You can split into train/val/test (here we use 80/10/10 split)
total_size = len(full_dataset)
train_size = int(0.8 * total_size)
val_size = int(0.1 * total_size)
test_size = total_size - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(full_dataset, [train_size, val_size, test_size])
batch_size = 16  # adjust as needed

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_loader   = DataLoader(val_dataset,   batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
test_loader  = DataLoader(test_dataset,  batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

# --- Use your model definition (TwoBranchWithGPT2) unchanged ---
# (Below is your unchanged model code; make sure it is defined exactly as you need)
class TwoBranchWithGPT2(nn.Module):
    def __init__(self, pretrained=True):
        super(TwoBranchWithGPT2, self).__init__()
        # -- SWIN Models --
        self.swin_global = create_model('swin_tiny_patch4_window7_224', pretrained=pretrained)
        self.swin_global.head = nn.Identity()
        self.swin_patch = create_model('swin_tiny_patch4_window7_224', pretrained=pretrained)
        self.swin_patch.head = nn.Identity()

        # -- ResNet Models --
        from torchvision import models
        resnet_global = models.resnet50(pretrained=pretrained)
        resnet_patch  = models.resnet50(pretrained=pretrained)
        resnet_global.fc = nn.Identity()
        resnet_patch.fc  = nn.Identity()
        self.resnet_global = resnet_global
        self.resnet_patch  = resnet_patch

        # -- Convert patch channels from 102 to 3 --
        self.patch_channel_reduction = nn.Conv2d(in_channels=102, out_channels=3, kernel_size=1)

        # -- Merge & project features to GPT2 hidden size (768) --
        self.feature_attention = nn.Sequential(
            nn.Linear(5632, 768),
            nn.ReLU()
        )

        # -- Classification head (unused in caption training) --
        self.classifier = nn.Sequential(
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Linear(512, 1),
            nn.Sigmoid()
        )

        # -- GPT-2 with cross-attention --
        gpt2_config = GPT2Config.from_pretrained("gpt2", add_cross_attention=True)
        self.gpt2 = GPT2LMHeadModel.from_pretrained("gpt2", config=gpt2_config)
        self.gpt2.resize_token_embeddings(len(tokenizer))
        # Initialize cross-attention layers
        for name, param in self.gpt2.named_parameters():
            if 'crossattention' in name:
                param.data.normal_(mean=0.0, std=0.02)

        # -- Prefix Projector --
        self.prefix_length = 10  # can experiment with this
        self.prefix_projector = nn.Linear(768, 768 * self.prefix_length)

    def encode_features(self, images, patches):
        # Resize patches to image size and reduce channels
        patches_resized = F.interpolate(patches, size=(224, 224), mode='bilinear', align_corners=False)
        patches_reduced = self.patch_channel_reduction(patches_resized)

        swin_global_features = self.swin_global.forward_features(images).mean(dim=[1, 2])
        swin_patch_features = self.swin_patch.forward_features(patches_reduced).mean(dim=[1, 2])
        resnet_global_features = self.resnet_global(images)
        resnet_patch_features = self.resnet_patch(patches_reduced)

        combined_features = torch.cat([swin_global_features, swin_patch_features,
                                         resnet_global_features, resnet_patch_features], dim=1)
        projected_features = self.feature_attention(combined_features)
        projected_features = F.normalize(projected_features, dim=-1)
        return projected_features

    def forward(self, images, patches, input_ids=None, attention_mask=None):
        projected_features = self.encode_features(images, patches)
        # We still compute classification output (unused in caption LM training)
        cls_output = self.classifier(projected_features)

        if input_ids is not None:
            batch_size = projected_features.size(0)
            encoder_hidden_states = self.prefix_projector(projected_features)
            encoder_hidden_states = encoder_hidden_states.view(batch_size, self.prefix_length, 768)
            # Teacher forcing mode: pass LM input tokens along with cross-attention
            gpt_outputs = self.gpt2(
                input_ids=input_ids,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states
            )
            return cls_output, gpt_outputs.logits
        else:
            return cls_output

    def generate_reports(self, projected_features, tokenizer, max_length=50, 
                           do_sample=True, top_k=50, top_p=0.95, temperature=0.7):
        device = projected_features.device
        batch_size = projected_features.size(0)
        bos_id = tokenizer.bos_token_id
        eos_id = tokenizer.eos_token_id

        projected_features = F.normalize(projected_features, dim=-1)
        encoder_hidden_states = self.prefix_projector(projected_features)
        encoder_hidden_states = encoder_hidden_states.view(batch_size, self.prefix_length, 768)

        input_ids = torch.full((batch_size, 1), bos_id, device=device, dtype=torch.long)
        encoder_attention_mask = torch.ones((batch_size, self.prefix_length), device=device, dtype=torch.long)

        generated_ids = self.gpt2.generate(
            input_ids=input_ids,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            max_length=max_length,
            do_sample=do_sample,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            repetition_penalty=1.2,
            bos_token_id=bos_id,
            eos_token_id=eos_id,
            pad_token_id=tokenizer.pad_token_id
        )
        generated_texts = [tokenizer.decode(seq.tolist(), skip_special_tokens=True) for seq in generated_ids]
        return generated_texts

# --- Set up device, model, loss, optimizer, and scheduler ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = TwoBranchWithGPT2(pretrained=True).to(device)

# We use the text generation loss only (CrossEntropyLoss)
criterion_txt = nn.CrossEntropyLoss(ignore_index=-100)
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

# --- Training Loop ---
num_epochs = 1  # adjust as needed
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
    for images, patches, input_ids, attention_mask, labels in pbar:
        images = images.to(device)
        patches = patches.to(device)
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        # We call the model in teacher-forcing mode.
        _, gpt_logits = model(images, patches, input_ids, attention_mask)
        # Flatten logits and labels for loss computation.
        loss = criterion_txt(gpt_logits.reshape(-1, gpt_logits.size(-1)), labels.reshape(-1))
        loss.backward()
        optimizer.step()
        scheduler.step()

        running_loss += loss.item()
        pbar.set_postfix({"loss": f"{loss.item():.4f}"})
    avg_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch+1} Training Loss: {avg_loss:.4f}")

    # Validation step (optional)
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for images, patches, input_ids, attention_mask, labels in val_loader:
            images = images.to(device)
            patches = patches.to(device)
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            labels = labels.to(device)

            _, gpt_logits = model(images, patches, input_ids, attention_mask)
            loss = criterion_txt(gpt_logits.reshape(-1, gpt_logits.size(-1)), labels.reshape(-1))
            val_loss += loss.item()
    avg_val_loss = val_loss / len(val_loader)
    print(f"Epoch {epoch+1} Validation Loss: {avg_val_loss:.4f}")

print("Training Complete!")

# --- Testing / Inference: Generate Reports ---
model.eval()
all_generated_texts = []
with torch.no_grad():
    for images, patches, input_ids, attention_mask, labels in test_loader:
        images = images.to(device)
        patches = patches.to(device)
        projected_features = model.encode_features(images, patches)
        gen_texts = model.generate_reports(projected_features, tokenizer, max_length=60)
        all_generated_texts.extend(gen_texts)

# Print a few sample generated reports
print("\n=== Sample Generated Reports ===\n")
for i, report in enumerate(all_generated_texts[:3]):
    print(f"[Generated Report {i+1}]:\n{report}")
    print("--------------------------------------------------")


2025-03-13 17:27:00.946291: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-03-13 17:27:00.951949: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-03-13 17:27:00.958791: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-03-13 17:27:00.960861: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-13 17:27:00.966050: I tensorflow/core/platform/cpu_feature_guar

Epoch 1 Training Loss: 2.2289
Epoch 1 Validation Loss: 1.2166
Training Complete!


The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.



=== Sample Generated Reports ===

[Generated Report 1]:
FINDING 
a) Rheumatoid arthritis (RA). b. RA involvement, both hands and feet: no significant difference at all on radiographs or MRI scans of joint space narrowing 2 years ago - 5th MTPR examination 4/5 1st 3rd TMT
--------------------------------------------------
[Generated Report 2]:
The T-bone joint in both hands.
Tibial osteopenia, and no significant change since the last study (possible RA involvement).  
 CONCLUSION - In addition to Rt 3rd MTP joints with moderate reduction of 5th RD3r5mTR
--------------------------------------------------
[Generated Report 3]:
FINDING 
Cortical flexion, both. - no significant changes at T1st MTP joint and Rt 1st MTJT joints > 2nd RA involvement but with PPPOLEON formation (no change) < 5th DLR-TRG correlation
--------------------------------------------------


In [2]:
import os, pickle, random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm

# --- Import your model components and tokenizer ---
from timm import create_model
from transformers import GPT2TokenizerFast, GPT2LMHeadModel, GPT2Config

# --- Use Albumentations for image augmentation (as in your reference code) ---
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image

# --- Set up the tokenizer ---
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
# GPT-2 does not have a pad token so we use EOS
tokenizer.pad_token = tokenizer.eos_token

# --- Load your data (pkl file) ---
with open('updated_merge_json_200x300.pkl', 'rb') as f:
    data = pickle.load(f)

# --- Balance the dataset ---
task_classes = ['ra', 'normal']
class_counts = {cls: 0 for cls in task_classes}
data_by_class = {cls: [] for cls in task_classes}

for key, entry in data.items():
    # use lower-case class label and only include if image file exists
    class_label = entry.get('class', '').lower()
    if class_label in class_counts and os.path.exists(entry['file_path']):
        class_counts[class_label] += 1
        data_by_class[class_label].append(entry)

min_class_count = min(class_counts.values())
balanced_data = []
for cls in task_classes:
    balanced_data.extend(random.sample(data_by_class[cls], min_class_count))

# Optionally, you may multiply the balanced_data for augmentation purposes:
augmented_data = balanced_data * 2

# --- Define Albumentations transforms ---
train_tfms = A.Compose([
    A.Resize(224, 224),
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.5),
    A.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5], always_apply=True),
    ToTensorV2()
])

valid_tfms = A.Compose([
    A.Resize(224, 224),
    A.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5], always_apply=True),
    ToTensorV2()
])

# For patch images we use a similar transform (you can adjust if needed)
patch_tfms = A.Compose([
    A.Resize(112, 112),
    A.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5], always_apply=True),
    ToTensorV2()
])

# --- Define a new Dataset for report generation ---
class ReportDataset(Dataset):
    def __init__(self, data, img_tfms, patch_tfms, tokenizer, max_length=128):
        self.data = data
        self.img_tfms = img_tfms
        self.patch_tfms = patch_tfms
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        entry = self.data[idx]
        
        # Load the main image
        image = Image.open(entry['file_path']).convert('RGB')
        image = np.array(image)
        augmented = self.img_tfms(image=image)
        image_tensor = augmented['image']
        
        # Process patch images (if any)
        patches = entry.get('bbx', [])
        if len(patches) > 0:
            # Limit to up to 34 patches (as before)
            patch_imgs = [Image.fromarray(patch) for patch in patches[:34]]
            patch_tensors = []
            for p in patch_imgs:
                p = np.array(p)
                aug_patch = self.patch_tfms(image=p)
                patch_tensors.append(aug_patch['image'])
            # Concatenate along the channel dimension: resulting shape (34*3, 112, 112)
            combined_patches = torch.cat(patch_tensors, dim=0)
        else:
            combined_patches = torch.zeros(34*3, 112, 112)
        
        # Process the report text (diagnosis field)
        # Replace unwanted tokens and append EOS token (using tokenizer.eos_token)
        report_text = entry['diagnosis'].replace('_x000D_', ' ').strip()
        caption = f"{report_text}{self.tokenizer.eos_token}"
        
        # Tokenize without padding here (we will pad in the collate function)
        encoding = self.tokenizer(
            caption,
            truncation=True,
            max_length=self.max_length,
            return_attention_mask=True,
            return_tensors="pt"
        )
        # Squeeze to remove batch dimension (now shape: [seq_len])
        input_ids = encoding['input_ids'].squeeze(0)
        attention_mask = encoding['attention_mask'].squeeze(0)
        
        # Create shifted labels for LM training (shift left by one)
        # For example: if input_ids = [a, b, c, eos] then labels = [b, c, eos, eos]
        labels = input_ids.clone()
        if input_ids.size(0) > 1:
            labels[:-1] = input_ids[1:]
        # (Optionally, you can set the last token’s label to -100 so that loss is not computed)
        # labels[-1] = -100
        
        return image_tensor, combined_patches, input_ids, attention_mask, labels

# --- Define a collate function to pad variable-length sequences ---
def collate_fn(batch):
    images, patches, input_ids_list, attn_masks, labels_list = zip(*batch)
    images = torch.stack(images, dim=0)
    patches = torch.stack(patches, dim=0)
    
    # Use the tokenizer's pad method to pad the input_ids and labels
    batch_encoding = tokenizer.pad(
        {'input_ids': list(input_ids_list)},
        padding='longest',
        return_tensors='pt'
    )
    input_ids = batch_encoding['input_ids']
    
    batch_encoding_labels = tokenizer.pad(
        {'input_ids': list(labels_list)},
        padding='longest',
        return_tensors='pt'
    )
    labels = batch_encoding_labels['input_ids']
    
    # Pad attention masks similarly
    batch_attn = tokenizer.pad(
        {'input_ids': list(attn_masks)},
        padding='longest',
        return_tensors='pt'
    )
    attention_mask = batch_attn['input_ids']
    
    # Replace pad token positions in labels with -100 so that loss is not computed on them
    labels[ input_ids == tokenizer.pad_token_id ] = -100
    
    return images, patches, input_ids, attention_mask, labels

# --- Create Dataset and DataLoaders ---
full_dataset = ReportDataset(augmented_data, train_tfms, patch_tfms, tokenizer, max_length=128)

# You can split into train/val/test (here we use 80/10/10 split)
total_size = len(full_dataset)
train_size = int(0.8 * total_size)
val_size = int(0.1 * total_size)
test_size = total_size - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(full_dataset, [train_size, val_size, test_size])
batch_size = 16  # adjust as needed

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_loader   = DataLoader(val_dataset,   batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
test_loader  = DataLoader(test_dataset,  batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

# --- Use your model definition (TwoBranchWithGPT2) unchanged ---
# (Below is your unchanged model code; make sure it is defined exactly as you need)
class TwoBranchWithGPT2(nn.Module):
    def __init__(self, pretrained=True):
        super(TwoBranchWithGPT2, self).__init__()
        # -- SWIN Models --
        self.swin_global = create_model('swin_tiny_patch4_window7_224', pretrained=pretrained)
        self.swin_global.head = nn.Identity()
        self.swin_patch = create_model('swin_tiny_patch4_window7_224', pretrained=pretrained)
        self.swin_patch.head = nn.Identity()

        # -- ResNet Models --
        from torchvision import models
        resnet_global = models.resnet50(pretrained=pretrained)
        resnet_patch  = models.resnet50(pretrained=pretrained)
        resnet_global.fc = nn.Identity()
        resnet_patch.fc  = nn.Identity()
        self.resnet_global = resnet_global
        self.resnet_patch  = resnet_patch

        # -- Convert patch channels from 102 to 3 --
        self.patch_channel_reduction = nn.Conv2d(in_channels=102, out_channels=3, kernel_size=1)

        # -- Merge & project features to GPT2 hidden size (768) --
        self.feature_attention = nn.Sequential(
            nn.Linear(5632, 768),
            nn.ReLU()
        )

        # -- Classification head (unused in caption training) --
        self.classifier = nn.Sequential(
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Linear(512, 1),
            nn.Sigmoid()
        )

        # -- GPT-2 with cross-attention --
        gpt2_config = GPT2Config.from_pretrained("gpt2", add_cross_attention=True)
        self.gpt2 = GPT2LMHeadModel.from_pretrained("gpt2", config=gpt2_config)
        self.gpt2.resize_token_embeddings(len(tokenizer))
        # Initialize cross-attention layers
        for name, param in self.gpt2.named_parameters():
            if 'crossattention' in name:
                param.data.normal_(mean=0.0, std=0.02)

        # -- Prefix Projector --
        self.prefix_length = 10  # can experiment with this
        self.prefix_projector = nn.Linear(768, 768 * self.prefix_length)

    def encode_features(self, images, patches):
        # Resize patches to image size and reduce channels
        patches_resized = F.interpolate(patches, size=(224, 224), mode='bilinear', align_corners=False)
        patches_reduced = self.patch_channel_reduction(patches_resized)

        swin_global_features = self.swin_global.forward_features(images).mean(dim=[1, 2])
        swin_patch_features = self.swin_patch.forward_features(patches_reduced).mean(dim=[1, 2])
        resnet_global_features = self.resnet_global(images)
        resnet_patch_features = self.resnet_patch(patches_reduced)

        combined_features = torch.cat([swin_global_features, swin_patch_features,
                                         resnet_global_features, resnet_patch_features], dim=1)
        projected_features = self.feature_attention(combined_features)
        projected_features = F.normalize(projected_features, dim=-1)
        return projected_features

    def forward(self, images, patches, input_ids=None, attention_mask=None):
        projected_features = self.encode_features(images, patches)
        # We still compute classification output (unused in caption LM training)
        cls_output = self.classifier(projected_features)

        if input_ids is not None:
            batch_size = projected_features.size(0)
            encoder_hidden_states = self.prefix_projector(projected_features)
            encoder_hidden_states = encoder_hidden_states.view(batch_size, self.prefix_length, 768)
            # Teacher forcing mode: pass LM input tokens along with cross-attention
            gpt_outputs = self.gpt2(
                input_ids=input_ids,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states
            )
            return cls_output, gpt_outputs.logits
        else:
            return cls_output

    def generate_reports(self, projected_features, tokenizer, max_length=50, 
                           do_sample=True, top_k=50, top_p=0.95, temperature=0.7):
        device = projected_features.device
        batch_size = projected_features.size(0)
        bos_id = tokenizer.bos_token_id
        eos_id = tokenizer.eos_token_id

        projected_features = F.normalize(projected_features, dim=-1)
        encoder_hidden_states = self.prefix_projector(projected_features)
        encoder_hidden_states = encoder_hidden_states.view(batch_size, self.prefix_length, 768)

        input_ids = torch.full((batch_size, 1), bos_id, device=device, dtype=torch.long)
        encoder_attention_mask = torch.ones((batch_size, self.prefix_length), device=device, dtype=torch.long)

        generated_ids = self.gpt2.generate(
            input_ids=input_ids,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            max_length=max_length,
            do_sample=do_sample,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            repetition_penalty=1.2,
            bos_token_id=bos_id,
            eos_token_id=eos_id,
            pad_token_id=tokenizer.pad_token_id
        )
        generated_texts = [tokenizer.decode(seq.tolist(), skip_special_tokens=True) for seq in generated_ids]
        return generated_texts

# --- Set up device, model, loss, optimizer, and scheduler ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = TwoBranchWithGPT2(pretrained=True).to(device)

# We use the text generation loss only (CrossEntropyLoss)
criterion_txt = nn.CrossEntropyLoss(ignore_index=-100)
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

# --- Training Loop ---
num_epochs = 1  # adjust as needed
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
    for images, patches, input_ids, attention_mask, labels in pbar:
        images = images.to(device)
        patches = patches.to(device)
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        # We call the model in teacher-forcing mode.
        _, gpt_logits = model(images, patches, input_ids, attention_mask)
        # Flatten logits and labels for loss computation.
        loss = criterion_txt(gpt_logits.reshape(-1, gpt_logits.size(-1)), labels.reshape(-1))
        loss.backward()
        optimizer.step()
        scheduler.step()

        running_loss += loss.item()
        pbar.set_postfix({"loss": f"{loss.item():.4f}"})
    avg_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch+1} Training Loss: {avg_loss:.4f}")

    # Validation step (optional)
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for images, patches, input_ids, attention_mask, labels in val_loader:
            images = images.to(device)
            patches = patches.to(device)
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            labels = labels.to(device)

            _, gpt_logits = model(images, patches, input_ids, attention_mask)
            loss = criterion_txt(gpt_logits.reshape(-1, gpt_logits.size(-1)), labels.reshape(-1))
            val_loss += loss.item()
    avg_val_loss = val_loss / len(val_loader)
    print(f"Epoch {epoch+1} Validation Loss: {avg_val_loss:.4f}")

print("Training Complete!")

# --- Testing / Inference: Generate Reports and Print Ground Truth ---
model.eval()
all_generated_texts = []
all_ground_truth_texts = []

with torch.no_grad():
    for images, patches, input_ids, attention_mask, labels in test_loader:
        images = images.to(device)
        patches = patches.to(device)
        # Get the encoded image features for caption generation.
        projected_features = model.encode_features(images, patches)
        # Generate reports for the current batch.
        gen_texts = model.generate_reports(projected_features, tokenizer, max_length=60)
        all_generated_texts.extend(gen_texts)
        
        # Decode ground truth reports from input_ids.
        # Note: We decode the first token sequence in each sample (removing padding tokens).
        for seq in input_ids:
            gt_text = tokenizer.decode(seq.tolist(), skip_special_tokens=True)
            all_ground_truth_texts.append(gt_text)

# Print a few sample generated reports along with the ground truth
print("\n=== Sample Generated Reports and Ground Truth Reports ===\n")
num_samples_to_show = 3
for i in range(num_samples_to_show):
    print(f"[Ground Truth Report {i+1}]:\n{all_ground_truth_texts[i]}\n")
    print(f"[Generated Report {i+1}]:\n{all_generated_texts[i]}\n")
    print("--------------------------------------------------")



  A.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5], always_apply=True),
  A.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5], always_apply=True),
  A.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5], always_apply=True),
Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.0.crossattention.c_attn.bias', 'h.0.crossattention.c_attn.weight', 'h.0.crossattention.c_proj.bias', 'h.0.crossattention.c_proj.weight', 'h.0.crossattention.q_attn.bias', 'h.0.crossattention.q_attn.weight', 'h.0.ln_cross_attn.bias', 'h.0.ln_cross_attn.weight', 'h.1.crossattention.c_attn.bias', 'h.1.crossattention.c_attn.weight', 'h.1.crossattention.c_proj.bias', 'h.1.crossattention.c_proj.weight', 'h.1.crossattention.q_attn.bias', 'h.1.crossattention.q_attn.weight', 'h.1.ln_cross_attn.bias', 'h.1.ln_cross_attn.weight', 'h.10.crossattention.c_attn.bias', 'h.10.crossattention.c_attn.weight', 'h.10.crossattention.c_proj.bias', 'h.10.crossattention.c_proj.wei

Epoch 1 Training Loss: 2.2164
Epoch 1 Validation Loss: 1.3563
Training Complete!


A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.



=== Sample Generated Reports and Ground Truth Reports ===

[Ground Truth Report 1]:
[ Finding ] 
Lt. 3-5th MCP joint, extensor hood, synovitis. 
Lt. 2-5th PIP joints, synovial thickening. 
Rt. 3rd MCP, extensor hood, soft tissue swelling, suggestive of synovitis. 
   ==> RA involvement, suggested. 
 
Rt. 1st IP joint, synovitis and calcification. 
  --> CPPD vs. RA involvement. 
 
Rt. 5th/4th/3rd/2nd MTP joints,

[Generated Report 1]:
FINDING 
A rt. joint, both knees and elbows up to hip bone level at knee's side of head - RA involvement? 1st MTPM ? (OR/R)? 4th RTRT ???) 2nd MAO type 3rd MTKMT K

--------------------------------------------------
[Ground Truth Report 2]:
[ Finding ] 
Hallux valgus. 
[ Diagnosis ] 
 
[ Recommend ]

[Generated Report 2]:
FINDING 
a bony abnormality. It is likely an underlying degenerative change in the frontal, posterior temporal lobes and occipital lobe/cortical joint space due to erosions on both sides of RA axis - possibly Rt 5th MTP junction?

-----

In [10]:
import os, pickle, random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm

# --- Import your model components and tokenizer ---
from timm import create_model
from transformers import GPT2TokenizerFast, GPT2LMHeadModel, GPT2Config

# --- Use Albumentations for image augmentation (as in your reference code) ---
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image

from nltk.translate.bleu_score import sentence_bleu,SmoothingFunction
from rouge import Rouge
import numpy as np
import evaluate

# --- Set up the tokenizer ---
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
# GPT-2 does not have a pad token so we use EOS
tokenizer.pad_token = tokenizer.eos_token

# --- Load your data (pkl file) ---
with open('updated_merge_json_200x300.pkl', 'rb') as f:
    data = pickle.load(f)

# --- Balance the dataset ---
task_classes = ['ra', 'normal']
class_counts = {cls: 0 for cls in task_classes}
data_by_class = {cls: [] for cls in task_classes}

for key, entry in data.items():
    # use lower-case class label and only include if image file exists
    class_label = entry.get('class', '').lower()
    if class_label in class_counts and os.path.exists(entry['file_path']):
        class_counts[class_label] += 1
        data_by_class[class_label].append(entry)

min_class_count = min(class_counts.values())
balanced_data = []
for cls in task_classes:
    balanced_data.extend(random.sample(data_by_class[cls], min_class_count))

# Optionally, you may multiply the balanced_data for augmentation purposes:
augmented_data = balanced_data * 2

# --- Define Albumentations transforms ---
train_tfms = A.Compose([
    A.Resize(224, 224),
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.5),
    A.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5], always_apply=True),
    ToTensorV2()
])

valid_tfms = A.Compose([
    A.Resize(224, 224),
    A.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5], always_apply=True),
    ToTensorV2()
])

# For patch images we use a similar transform (you can adjust if needed)
patch_tfms = A.Compose([
    A.Resize(112, 112),
    A.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5], always_apply=True),
    ToTensorV2()
])

# --- Define a new Dataset for report generation ---
class ReportDataset(Dataset):
    def __init__(self, data, img_tfms, patch_tfms, tokenizer, max_length=128):
        self.data = data
        self.img_tfms = img_tfms
        self.patch_tfms = patch_tfms
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        entry = self.data[idx]
        
        # Load the main image
        image = Image.open(entry['file_path']).convert('RGB')
        image = np.array(image)
        augmented = self.img_tfms(image=image)
        image_tensor = augmented['image']
        
        # Process patch images (if any)
        patches = entry.get('bbx', [])
        if len(patches) > 0:
            # Limit to up to 34 patches (as before)
            patch_imgs = [Image.fromarray(patch) for patch in patches[:34]]
            patch_tensors = []
            for p in patch_imgs:
                p = np.array(p)
                aug_patch = self.patch_tfms(image=p)
                patch_tensors.append(aug_patch['image'])
            # Concatenate along the channel dimension: resulting shape (34*3, 112, 112)
            combined_patches = torch.cat(patch_tensors, dim=0)
        else:
            combined_patches = torch.zeros(34*3, 112, 112)
        
        # Process the report text (diagnosis field)
        # Replace unwanted tokens and append EOS token (using tokenizer.eos_token)
        report_text = entry['diagnosis'].replace('_x000D_', ' ').strip()
        caption = f"{report_text}{self.tokenizer.eos_token}"
        
        # Tokenize without padding here (we will pad in the collate function)
        encoding = self.tokenizer(
            caption,
            truncation=True,
            max_length=self.max_length,
            return_attention_mask=True,
            return_tensors="pt"
        )
        # Squeeze to remove batch dimension (now shape: [seq_len])
        input_ids = encoding['input_ids'].squeeze(0)
        attention_mask = encoding['attention_mask'].squeeze(0)
        
        # Create shifted labels for LM training (shift left by one)
        # For example: if input_ids = [a, b, c, eos] then labels = [b, c, eos, eos]
        labels = input_ids.clone()
        if input_ids.size(0) > 1:
            labels[:-1] = input_ids[1:]
        # (Optionally, you can set the last token’s label to -100 so that loss is not computed)
        # labels[-1] = -100
        
        return image_tensor, combined_patches, input_ids, attention_mask, labels

# --- Define a collate function to pad variable-length sequences ---
def collate_fn(batch):
    images, patches, input_ids_list, attn_masks, labels_list = zip(*batch)
    images = torch.stack(images, dim=0)
    patches = torch.stack(patches, dim=0)
    
    # Use the tokenizer's pad method to pad the input_ids and labels
    batch_encoding = tokenizer.pad(
        {'input_ids': list(input_ids_list)},
        padding='longest',
        return_tensors='pt'
    )
    input_ids = batch_encoding['input_ids']
    
    batch_encoding_labels = tokenizer.pad(
        {'input_ids': list(labels_list)},
        padding='longest',
        return_tensors='pt'
    )
    labels = batch_encoding_labels['input_ids']
    
    # Pad attention masks similarly
    batch_attn = tokenizer.pad(
        {'input_ids': list(attn_masks)},
        padding='longest',
        return_tensors='pt'
    )
    attention_mask = batch_attn['input_ids']
    
    # Replace pad token positions in labels with -100 so that loss is not computed on them
    labels[ input_ids == tokenizer.pad_token_id ] = -100
    
    return images, patches, input_ids, attention_mask, labels

# --- Create Dataset and DataLoaders ---
full_dataset = ReportDataset(augmented_data, train_tfms, patch_tfms, tokenizer, max_length=128)

# You can split into train/val/test (here we use 80/10/10 split)
total_size = len(full_dataset)
train_size = int(0.8 * total_size)
val_size = int(0.1 * total_size)
test_size = total_size - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(full_dataset, [train_size, val_size, test_size])
batch_size = 16  # adjust as needed

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_loader   = DataLoader(val_dataset,   batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
test_loader  = DataLoader(test_dataset,  batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

# --- Use your model definition (TwoBranchWithGPT2) unchanged ---
# (Below is your unchanged model code; make sure it is defined exactly as you need)
class TwoBranchWithGPT2(nn.Module):
    def __init__(self, pretrained=True):
        super(TwoBranchWithGPT2, self).__init__()
        # -- SWIN Models --
        self.swin_global = create_model('swin_tiny_patch4_window7_224', pretrained=pretrained)
        self.swin_global.head = nn.Identity()
        self.swin_patch = create_model('swin_tiny_patch4_window7_224', pretrained=pretrained)
        self.swin_patch.head = nn.Identity()

        # -- ResNet Models --
        from torchvision import models
        resnet_global = models.resnet50(pretrained=pretrained)
        resnet_patch  = models.resnet50(pretrained=pretrained)
        resnet_global.fc = nn.Identity()
        resnet_patch.fc  = nn.Identity()
        self.resnet_global = resnet_global
        self.resnet_patch  = resnet_patch

        # -- Convert patch channels from 102 to 3 --
        self.patch_channel_reduction = nn.Conv2d(in_channels=102, out_channels=3, kernel_size=1)

        # -- Merge & project features to GPT2 hidden size (768) --
        self.feature_attention = nn.Sequential(
            nn.Linear(5632, 768),
            nn.ReLU()
        )

        # -- Classification head (unused in caption training) --
        self.classifier = nn.Sequential(
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Linear(512, 1),
            nn.Sigmoid()
        )

        # -- GPT-2 with cross-attention --
        gpt2_config = GPT2Config.from_pretrained("gpt2", add_cross_attention=True)
        self.gpt2 = GPT2LMHeadModel.from_pretrained("gpt2", config=gpt2_config)
        self.gpt2.resize_token_embeddings(len(tokenizer))
        # Initialize cross-attention layers
        for name, param in self.gpt2.named_parameters():
            if 'crossattention' in name:
                param.data.normal_(mean=0.0, std=0.02)

        # -- Prefix Projector --
        self.prefix_length = 10  # can experiment with this
        self.prefix_projector = nn.Linear(768, 768 * self.prefix_length)

    def encode_features(self, images, patches):
        # Resize patches to image size and reduce channels
        patches_resized = F.interpolate(patches, size=(224, 224), mode='bilinear', align_corners=False)
        patches_reduced = self.patch_channel_reduction(patches_resized)

        swin_global_features = self.swin_global.forward_features(images).mean(dim=[1, 2])
        swin_patch_features = self.swin_patch.forward_features(patches_reduced).mean(dim=[1, 2])
        resnet_global_features = self.resnet_global(images)
        resnet_patch_features = self.resnet_patch(patches_reduced)

        combined_features = torch.cat([swin_global_features, swin_patch_features,
                                         resnet_global_features, resnet_patch_features], dim=1)
        projected_features = self.feature_attention(combined_features)
        projected_features = F.normalize(projected_features, dim=-1)
        return projected_features

    def forward(self, images, patches, input_ids=None, attention_mask=None):
        projected_features = self.encode_features(images, patches)
        # We still compute classification output (unused in caption LM training)
        cls_output = self.classifier(projected_features)

        if input_ids is not None:
            batch_size = projected_features.size(0)
            encoder_hidden_states = self.prefix_projector(projected_features)
            encoder_hidden_states = encoder_hidden_states.view(batch_size, self.prefix_length, 768)
            # Teacher forcing mode: pass LM input tokens along with cross-attention
            gpt_outputs = self.gpt2(
                input_ids=input_ids,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states
            )
            return cls_output, gpt_outputs.logits
        else:
            return cls_output

    def generate_reports(self, projected_features, tokenizer, max_length=50, 
                           do_sample=True, top_k=50, top_p=0.95, temperature=0.7):
        device = projected_features.device
        batch_size = projected_features.size(0)
        bos_id = tokenizer.bos_token_id
        eos_id = tokenizer.eos_token_id

        projected_features = F.normalize(projected_features, dim=-1)
        encoder_hidden_states = self.prefix_projector(projected_features)
        encoder_hidden_states = encoder_hidden_states.view(batch_size, self.prefix_length, 768)

        input_ids = torch.full((batch_size, 1), bos_id, device=device, dtype=torch.long)
        encoder_attention_mask = torch.ones((batch_size, self.prefix_length), device=device, dtype=torch.long)

        generated_ids = self.gpt2.generate(
            input_ids=input_ids,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            max_length=max_length,
            do_sample=do_sample,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            repetition_penalty=1.2,
            bos_token_id=bos_id,
            eos_token_id=eos_id,
            pad_token_id=tokenizer.pad_token_id
        )
        generated_texts = [tokenizer.decode(seq.tolist(), skip_special_tokens=True) for seq in generated_ids]
        return generated_texts

# --- Set up device, model, loss, optimizer, and scheduler ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = TwoBranchWithGPT2(pretrained=True).to(device)

# We use the text generation loss only (CrossEntropyLoss)
criterion_txt = nn.CrossEntropyLoss(ignore_index=-100)
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

# --- Training Loop ---
num_epochs = 10  # adjust as needed
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
    for images, patches, input_ids, attention_mask, labels in pbar:
        images = images.to(device)
        patches = patches.to(device)
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        # We call the model in teacher-forcing mode.
        _, gpt_logits = model(images, patches, input_ids, attention_mask)
        # Flatten logits and labels for loss computation.
        loss = criterion_txt(gpt_logits.reshape(-1, gpt_logits.size(-1)), labels.reshape(-1))
        loss.backward()
        optimizer.step()
        scheduler.step()

        running_loss += loss.item()
        pbar.set_postfix({"loss": f"{loss.item():.4f}"})
    avg_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch+1} Training Loss: {avg_loss:.4f}")

    # Validation step (optional)
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for images, patches, input_ids, attention_mask, labels in val_loader:
            images = images.to(device)
            patches = patches.to(device)
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            labels = labels.to(device)

            _, gpt_logits = model(images, patches, input_ids, attention_mask)
            loss = criterion_txt(gpt_logits.reshape(-1, gpt_logits.size(-1)), labels.reshape(-1))
            val_loss += loss.item()
    avg_val_loss = val_loss / len(val_loader)
    print(f"Epoch {epoch+1} Validation Loss: {avg_val_loss:.4f}")

print("Training Complete!")

# --- Testing / Inference: Generate Reports and Print Ground Truth ---
model.eval()
all_generated_texts = []
all_ground_truth_texts = []

with torch.no_grad():
    for images, patches, input_ids, attention_mask, labels in test_loader:
        images = images.to(device)
        patches = patches.to(device)
        # Get the encoded image features for caption generation.
        projected_features = model.encode_features(images, patches)
        # Generate reports for the current batch.
        gen_texts = model.generate_reports(projected_features, tokenizer, max_length=60)
        all_generated_texts.extend(gen_texts)
        
        # Decode ground truth reports from input_ids.
        # Note: We decode the first token sequence in each sample (removing padding tokens).
        for seq in input_ids:
            gt_text = tokenizer.decode(seq.tolist(), skip_special_tokens=True)
            all_ground_truth_texts.append(gt_text)

# Print a few sample generated reports along with the ground truth
print("\n=== Sample Generated Reports and Ground Truth Reports ===\n")
num_samples_to_show = 5
for i in range(num_samples_to_show):
    print(f"[Ground Truth Report {i+1}]:\n{all_ground_truth_texts[i]}\n")
    print(f"[Generated Report {i+1}]:\n{all_generated_texts[i]}\n")
    print("--------------------------------------------------")

def evaluate_generated_texts(ground_truth_texts, generated_texts):
    """
    Evaluate generated texts using BLEU (with smoothing) and ROUGE-L scores.
    """
    rouge = Rouge()
    bleu_scores = []
    rouge_l_scores = []
    smoothing_fn = SmoothingFunction().method1  # Use smoothing to avoid zero counts for higher n-grams

    for ref_text, gen_text in zip(ground_truth_texts, generated_texts):
        # Tokenize by splitting on whitespace
        ref_tokens = ref_text.split()
        gen_tokens = gen_text.split()
        # Calculate BLEU score for this sample with smoothing
        bleu = sentence_bleu([ref_tokens], gen_tokens, smoothing_function=smoothing_fn)
        bleu_scores.append(bleu)
        try:
            scores = rouge.get_scores(gen_text, ref_text)
            rouge_l_scores.append(scores[0]["rouge-l"]["f"])
        except ValueError:
            rouge_l_scores.append(0.0)

    avg_bleu = np.mean(bleu_scores) if bleu_scores else 0.0
    avg_rouge = np.mean(rouge_l_scores) if rouge_l_scores else 0.0
    return avg_bleu, avg_rouge

# Assuming you have collected all_ground_truth_texts and all_generated_texts during inference:
avg_bleu, avg_rouge = evaluate_generated_texts(all_ground_truth_texts, all_generated_texts)

print("=== Evaluation of Generated Reports ===")
print(f"Average BLEU Score (smoothed): {avg_bleu:.4f}")
print(f"Average ROUGE-L Score:         {avg_rouge:.4f}")

bertscore_metric = evaluate.load("bertscore")

def evaluate_with_bertscore(generated_texts, ground_truth_texts):
    results = bertscore_metric.compute(predictions=generated_texts, references=ground_truth_texts, lang="en")
    avg_f1 = np.mean(results["f1"])
    return avg_f1

avg_bertscore = evaluate_with_bertscore(all_generated_texts, all_ground_truth_texts)
print(f"Average BERTScore F1: {avg_bertscore:.4f}")


  A.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5], always_apply=True),
  A.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5], always_apply=True),
  A.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5], always_apply=True),
Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.0.crossattention.c_attn.bias', 'h.0.crossattention.c_attn.weight', 'h.0.crossattention.c_proj.bias', 'h.0.crossattention.c_proj.weight', 'h.0.crossattention.q_attn.bias', 'h.0.crossattention.q_attn.weight', 'h.0.ln_cross_attn.bias', 'h.0.ln_cross_attn.weight', 'h.1.crossattention.c_attn.bias', 'h.1.crossattention.c_attn.weight', 'h.1.crossattention.c_proj.bias', 'h.1.crossattention.c_proj.weight', 'h.1.crossattention.q_attn.bias', 'h.1.crossattention.q_attn.weight', 'h.1.ln_cross_attn.bias', 'h.1.ln_cross_attn.weight', 'h.10.crossattention.c_attn.bias', 'h.10.crossattention.c_attn.weight', 'h.10.crossattention.c_proj.bias', 'h.10.crossattention.c_proj.wei

Epoch 1 Training Loss: 2.3529
Epoch 1 Validation Loss: 1.4922


Epoch 2/10: 100%|██████████| 18/18 [01:33<00:00,  5.17s/it, loss=1.1058]


Epoch 2 Training Loss: 1.3353
Epoch 2 Validation Loss: 1.1051


Epoch 3/10: 100%|██████████| 18/18 [01:31<00:00,  5.08s/it, loss=0.9250]


Epoch 3 Training Loss: 0.9782
Epoch 3 Validation Loss: 0.8755


Epoch 4/10: 100%|██████████| 18/18 [01:29<00:00,  4.96s/it, loss=0.8317]


Epoch 4 Training Loss: 0.8350
Epoch 4 Validation Loss: 0.8740


Epoch 5/10: 100%|██████████| 18/18 [01:30<00:00,  5.05s/it, loss=0.7283]


Epoch 5 Training Loss: 0.7457
Epoch 5 Validation Loss: 0.7147


Epoch 6/10: 100%|██████████| 18/18 [01:32<00:00,  5.13s/it, loss=0.5431]


Epoch 6 Training Loss: 0.5995
Epoch 6 Validation Loss: 0.6499


Epoch 7/10: 100%|██████████| 18/18 [01:30<00:00,  5.05s/it, loss=0.4917]


Epoch 7 Training Loss: 0.5261
Epoch 7 Validation Loss: 0.5871


Epoch 8/10: 100%|██████████| 18/18 [01:30<00:00,  5.04s/it, loss=0.4824]


Epoch 8 Training Loss: 0.4826
Epoch 8 Validation Loss: 0.5784


Epoch 9/10: 100%|██████████| 18/18 [01:31<00:00,  5.10s/it, loss=0.3822]


Epoch 9 Training Loss: 0.4830
Epoch 9 Validation Loss: 0.5370


Epoch 10/10: 100%|██████████| 18/18 [01:31<00:00,  5.06s/it, loss=0.4842]


Epoch 10 Training Loss: 0.4505
Epoch 10 Validation Loss: 0.4799
Training Complete!


A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.



=== Sample Generated Reports and Ground Truth Reports ===

[Ground Truth Report 1]:
[ Finding ] 
 
[ Diagnosis ] 
Rt. knee joint effusion 
Lt. accessory navicular bone, type II. 
 
LT. 1st MTP joint, bony erosion with periarticular osteopenia. 
 
Rt. 3rd finger DIP joint, bony erosion with soft tissue swelling. 
    
--> inflammatory arthritis, such as RA 
       D/Dx. infectious arthritis. 
 
[ Recommend ]

[Generated Report 1]:
FINDING 
no bony lesion. A possible inflammatory arthritis in the wrist of Ltantoaxial joint, both calcaneus and plantar area with suspicious erosion at base or accessory bone type II Rt 2nd MT shaft . - suggestive CTE involvement --- ==> RA

--------------------------------------------------
[Ground Truth Report 2]:
[FINDING       ] suspicious erosion at Lt 5th MTP joint 
-> r/o RA involvement   [CONCLUSION    ] suspicious erosion at Lt 5th MTP joint 
-> r/o RA involvement   [RECOMMENDATION] -

[Generated Report 2]:
FINDING 
No bony abnormality. No significa

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Average BERTScore F1: 0.8199


In [11]:
import os, pickle, random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm

# --- Import your model components and tokenizer ---
from timm import create_model
from transformers import GPT2TokenizerFast, GPT2LMHeadModel, GPT2Config

# --- Use Albumentations for image augmentation (as in your reference code) ---
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image

from nltk.translate.bleu_score import sentence_bleu,SmoothingFunction
from rouge import Rouge
import numpy as np
import evaluate

# --- Set up the tokenizer ---
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
# GPT-2 does not have a pad token so we use EOS
tokenizer.pad_token = tokenizer.eos_token

# --- Load your data (pkl file) ---
with open('updated_merge_json_200x300.pkl', 'rb') as f:
    data = pickle.load(f)

# --- Balance the dataset ---
task_classes = ['ra', 'normal']
class_counts = {cls: 0 for cls in task_classes}
data_by_class = {cls: [] for cls in task_classes}

for key, entry in data.items():
    # use lower-case class label and only include if image file exists
    class_label = entry.get('class', '').lower()
    if class_label in class_counts and os.path.exists(entry['file_path']):
        class_counts[class_label] += 1
        data_by_class[class_label].append(entry)

min_class_count = min(class_counts.values())
balanced_data = []
for cls in task_classes:
    balanced_data.extend(random.sample(data_by_class[cls], min_class_count))

# Optionally, you may multiply the balanced_data for augmentation purposes:
augmented_data = balanced_data * 2

# --- Define Albumentations transforms ---
train_tfms = A.Compose([
    A.Resize(224, 224),
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.5),
    A.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5], always_apply=True),
    ToTensorV2()
])

valid_tfms = A.Compose([
    A.Resize(224, 224),
    A.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5], always_apply=True),
    ToTensorV2()
])

# For patch images we use a similar transform (you can adjust if needed)
patch_tfms = A.Compose([
    A.Resize(112, 112),
    A.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5], always_apply=True),
    ToTensorV2()
])

# --- Define a new Dataset for report generation ---
class ReportDataset(Dataset):
    def __init__(self, data, img_tfms, patch_tfms, tokenizer, max_length=128):
        self.data = data
        self.img_tfms = img_tfms
        self.patch_tfms = patch_tfms
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        entry = self.data[idx]
        
        # Load the main image
        image = Image.open(entry['file_path']).convert('RGB')
        image = np.array(image)
        augmented = self.img_tfms(image=image)
        image_tensor = augmented['image']
        
        # Process patch images (if any)
        patches = entry.get('bbx', [])
        if len(patches) > 0:
            # Limit to up to 34 patches (as before)
            patch_imgs = [Image.fromarray(patch) for patch in patches[:34]]
            patch_tensors = []
            for p in patch_imgs:
                p = np.array(p)
                aug_patch = self.patch_tfms(image=p)
                patch_tensors.append(aug_patch['image'])
            # Concatenate along the channel dimension: resulting shape (34*3, 112, 112)
            combined_patches = torch.cat(patch_tensors, dim=0)
        else:
            combined_patches = torch.zeros(34*3, 112, 112)
        
        # Process the report text (diagnosis field)
        # Replace unwanted tokens and append EOS token (using tokenizer.eos_token)
        report_text = entry['diagnosis'].replace('_x000D_', ' ').strip()
        caption = f"{report_text}{self.tokenizer.eos_token}"
        
        # Tokenize without padding here (we will pad in the collate function)
        encoding = self.tokenizer(
            caption,
            truncation=True,
            max_length=self.max_length,
            return_attention_mask=True,
            return_tensors="pt"
        )
        # Squeeze to remove batch dimension (now shape: [seq_len])
        input_ids = encoding['input_ids'].squeeze(0)
        attention_mask = encoding['attention_mask'].squeeze(0)
        
        # Create shifted labels for LM training (shift left by one)
        # For example: if input_ids = [a, b, c, eos] then labels = [b, c, eos, eos]
        labels = input_ids.clone()
        if input_ids.size(0) > 1:
            labels[:-1] = input_ids[1:]
        # (Optionally, you can set the last token’s label to -100 so that loss is not computed)
        # labels[-1] = -100
        
        return image_tensor, combined_patches, input_ids, attention_mask, labels

# --- Define a collate function to pad variable-length sequences ---
def collate_fn(batch):
    images, patches, input_ids_list, attn_masks, labels_list = zip(*batch)
    images = torch.stack(images, dim=0)
    patches = torch.stack(patches, dim=0)
    
    # Use the tokenizer's pad method to pad the input_ids and labels
    batch_encoding = tokenizer.pad(
        {'input_ids': list(input_ids_list)},
        padding='longest',
        return_tensors='pt'
    )
    input_ids = batch_encoding['input_ids']
    
    batch_encoding_labels = tokenizer.pad(
        {'input_ids': list(labels_list)},
        padding='longest',
        return_tensors='pt'
    )
    labels = batch_encoding_labels['input_ids']
    
    # Pad attention masks similarly
    batch_attn = tokenizer.pad(
        {'input_ids': list(attn_masks)},
        padding='longest',
        return_tensors='pt'
    )
    attention_mask = batch_attn['input_ids']
    
    # Replace pad token positions in labels with -100 so that loss is not computed on them
    labels[ input_ids == tokenizer.pad_token_id ] = -100
    
    return images, patches, input_ids, attention_mask, labels

# --- Create Dataset and DataLoaders ---
full_dataset = ReportDataset(augmented_data, train_tfms, patch_tfms, tokenizer, max_length=128)

# You can split into train/val/test (here we use 80/10/10 split)
total_size = len(full_dataset)
train_size = int(0.8 * total_size)
val_size = int(0.1 * total_size)
test_size = total_size - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(full_dataset, [train_size, val_size, test_size])
batch_size = 16  # adjust as needed

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_loader   = DataLoader(val_dataset,   batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
test_loader  = DataLoader(test_dataset,  batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

# --- Use your model definition (TwoBranchWithGPT2) unchanged ---
# (Below is your unchanged model code; make sure it is defined exactly as you need)
class TwoBranchWithGPT2(nn.Module):
    def __init__(self, pretrained=True):
        super(TwoBranchWithGPT2, self).__init__()
        # -- SWIN Models --
        self.swin_global = create_model('swin_tiny_patch4_window7_224', pretrained=pretrained)
        self.swin_global.head = nn.Identity()
        self.swin_patch = create_model('swin_tiny_patch4_window7_224', pretrained=pretrained)
        self.swin_patch.head = nn.Identity()

        # -- ResNet Models --
        from torchvision import models
        resnet_global = models.resnet50(pretrained=pretrained)
        resnet_patch  = models.resnet50(pretrained=pretrained)
        resnet_global.fc = nn.Identity()
        resnet_patch.fc  = nn.Identity()
        self.resnet_global = resnet_global
        self.resnet_patch  = resnet_patch

        # -- Convert patch channels from 102 to 3 --
        self.patch_channel_reduction = nn.Conv2d(in_channels=102, out_channels=3, kernel_size=1)

        # -- Merge & project features to GPT2 hidden size (768) --
        self.feature_attention = nn.Sequential(
            nn.Linear(5632, 768),
            nn.ReLU()
        )

        # -- Classification head (unused in caption training) --
        self.classifier = nn.Sequential(
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Linear(512, 1),
            nn.Sigmoid()
        )

        # -- GPT-2 with cross-attention --
        gpt2_config = GPT2Config.from_pretrained("gpt2", add_cross_attention=True)
        self.gpt2 = GPT2LMHeadModel.from_pretrained("gpt2", config=gpt2_config)
        self.gpt2.resize_token_embeddings(len(tokenizer))
        # Initialize cross-attention layers
        for name, param in self.gpt2.named_parameters():
            if 'crossattention' in name:
                param.data.normal_(mean=0.0, std=0.02)

        # -- Prefix Projector --
        self.prefix_length = 10  # can experiment with this
        self.prefix_projector = nn.Linear(768, 768 * self.prefix_length)

    def encode_features(self, images, patches):
        # Resize patches to image size and reduce channels
        patches_resized = F.interpolate(patches, size=(224, 224), mode='bilinear', align_corners=False)
        patches_reduced = self.patch_channel_reduction(patches_resized)

        swin_global_features = self.swin_global.forward_features(images).mean(dim=[1, 2])
        swin_patch_features = self.swin_patch.forward_features(patches_reduced).mean(dim=[1, 2])
        resnet_global_features = self.resnet_global(images)
        resnet_patch_features = self.resnet_patch(patches_reduced)

        combined_features = torch.cat([swin_global_features, swin_patch_features,
                                         resnet_global_features, resnet_patch_features], dim=1)
        projected_features = self.feature_attention(combined_features)
        projected_features = F.normalize(projected_features, dim=-1)
        return projected_features

    def forward(self, images, patches, input_ids=None, attention_mask=None):
        projected_features = self.encode_features(images, patches)
        # We still compute classification output (unused in caption LM training)
        cls_output = self.classifier(projected_features)

        if input_ids is not None:
            batch_size = projected_features.size(0)
            encoder_hidden_states = self.prefix_projector(projected_features)
            encoder_hidden_states = encoder_hidden_states.view(batch_size, self.prefix_length, 768)
            # Teacher forcing mode: pass LM input tokens along with cross-attention
            gpt_outputs = self.gpt2(
                input_ids=input_ids,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states
            )
            return cls_output, gpt_outputs.logits
        else:
            return cls_output

    def generate_reports(self, projected_features, tokenizer, max_length=50, 
                           do_sample=True, top_k=50, top_p=0.95, temperature=0.7):
        device = projected_features.device
        batch_size = projected_features.size(0)
        bos_id = tokenizer.bos_token_id
        eos_id = tokenizer.eos_token_id

        projected_features = F.normalize(projected_features, dim=-1)
        encoder_hidden_states = self.prefix_projector(projected_features)
        encoder_hidden_states = encoder_hidden_states.view(batch_size, self.prefix_length, 768)

        input_ids = torch.full((batch_size, 1), bos_id, device=device, dtype=torch.long)
        encoder_attention_mask = torch.ones((batch_size, self.prefix_length), device=device, dtype=torch.long)

        generated_ids = self.gpt2.generate(
            input_ids=input_ids,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            max_length=max_length,
            do_sample=do_sample,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            repetition_penalty=1.2,
            bos_token_id=bos_id,
            eos_token_id=eos_id,
            pad_token_id=tokenizer.pad_token_id
        )
        generated_texts = [tokenizer.decode(seq.tolist(), skip_special_tokens=True) for seq in generated_ids]
        return generated_texts

# --- Set up device, model, loss, optimizer, and scheduler ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = TwoBranchWithGPT2(pretrained=True).to(device)

# We use the text generation loss only (CrossEntropyLoss)
criterion_txt = nn.CrossEntropyLoss(ignore_index=-100)
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

# --- Training Loop ---
num_epochs = 20  # adjust as needed
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
    for images, patches, input_ids, attention_mask, labels in pbar:
        images = images.to(device)
        patches = patches.to(device)
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        # We call the model in teacher-forcing mode.
        _, gpt_logits = model(images, patches, input_ids, attention_mask)
        # Flatten logits and labels for loss computation.
        loss = criterion_txt(gpt_logits.reshape(-1, gpt_logits.size(-1)), labels.reshape(-1))
        loss.backward()
        optimizer.step()
        scheduler.step()

        running_loss += loss.item()
        pbar.set_postfix({"loss": f"{loss.item():.4f}"})
    avg_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch+1} Training Loss: {avg_loss:.4f}")

    # Validation step (optional)
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for images, patches, input_ids, attention_mask, labels in val_loader:
            images = images.to(device)
            patches = patches.to(device)
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            labels = labels.to(device)

            _, gpt_logits = model(images, patches, input_ids, attention_mask)
            loss = criterion_txt(gpt_logits.reshape(-1, gpt_logits.size(-1)), labels.reshape(-1))
            val_loss += loss.item()
    avg_val_loss = val_loss / len(val_loader)
    print(f"Epoch {epoch+1} Validation Loss: {avg_val_loss:.4f}")

print("Training Complete!")

# --- Testing / Inference: Generate Reports and Print Ground Truth ---
model.eval()
all_generated_texts = []
all_ground_truth_texts = []

with torch.no_grad():
    for images, patches, input_ids, attention_mask, labels in test_loader:
        images = images.to(device)
        patches = patches.to(device)
        # Get the encoded image features for caption generation.
        projected_features = model.encode_features(images, patches)
        # Generate reports for the current batch.
        gen_texts = model.generate_reports(projected_features, tokenizer, max_length=60)
        all_generated_texts.extend(gen_texts)
        
        # Decode ground truth reports from input_ids.
        # Note: We decode the first token sequence in each sample (removing padding tokens).
        for seq in input_ids:
            gt_text = tokenizer.decode(seq.tolist(), skip_special_tokens=True)
            all_ground_truth_texts.append(gt_text)

# Print a few sample generated reports along with the ground truth
print("\n=== Sample Generated Reports and Ground Truth Reports ===\n")
num_samples_to_show = 5
for i in range(num_samples_to_show):
    print(f"[Ground Truth Report {i+1}]:\n{all_ground_truth_texts[i]}\n")
    print(f"[Generated Report {i+1}]:\n{all_generated_texts[i]}\n")
    print("--------------------------------------------------")

def evaluate_generated_texts(ground_truth_texts, generated_texts):
    """
    Evaluate generated texts using BLEU (with smoothing) and ROUGE-L scores.
    """
    rouge = Rouge()
    bleu_scores = []
    rouge_l_scores = []
    smoothing_fn = SmoothingFunction().method1  # Use smoothing to avoid zero counts for higher n-grams

    for ref_text, gen_text in zip(ground_truth_texts, generated_texts):
        # Tokenize by splitting on whitespace
        ref_tokens = ref_text.split()
        gen_tokens = gen_text.split()
        # Calculate BLEU score for this sample with smoothing
        bleu = sentence_bleu([ref_tokens], gen_tokens, smoothing_function=smoothing_fn)
        bleu_scores.append(bleu)
        try:
            scores = rouge.get_scores(gen_text, ref_text)
            rouge_l_scores.append(scores[0]["rouge-l"]["f"])
        except ValueError:
            rouge_l_scores.append(0.0)

    avg_bleu = np.mean(bleu_scores) if bleu_scores else 0.0
    avg_rouge = np.mean(rouge_l_scores) if rouge_l_scores else 0.0
    return avg_bleu, avg_rouge

# Assuming you have collected all_ground_truth_texts and all_generated_texts during inference:
avg_bleu, avg_rouge = evaluate_generated_texts(all_ground_truth_texts, all_generated_texts)

print("=== Evaluation of Generated Reports ===")
print(f"Average BLEU Score (smoothed): {avg_bleu:.4f}")
print(f"Average ROUGE-L Score:         {avg_rouge:.4f}")

bertscore_metric = evaluate.load("bertscore")

def evaluate_with_bertscore(generated_texts, ground_truth_texts):
    results = bertscore_metric.compute(predictions=generated_texts, references=ground_truth_texts, lang="en")
    avg_f1 = np.mean(results["f1"])
    return avg_f1

avg_bertscore = evaluate_with_bertscore(all_generated_texts, all_ground_truth_texts)
print(f"Average BERTScore F1: {avg_bertscore:.4f}")


  A.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5], always_apply=True),
  A.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5], always_apply=True),
  A.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5], always_apply=True),
Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.0.crossattention.c_attn.bias', 'h.0.crossattention.c_attn.weight', 'h.0.crossattention.c_proj.bias', 'h.0.crossattention.c_proj.weight', 'h.0.crossattention.q_attn.bias', 'h.0.crossattention.q_attn.weight', 'h.0.ln_cross_attn.bias', 'h.0.ln_cross_attn.weight', 'h.1.crossattention.c_attn.bias', 'h.1.crossattention.c_attn.weight', 'h.1.crossattention.c_proj.bias', 'h.1.crossattention.c_proj.weight', 'h.1.crossattention.q_attn.bias', 'h.1.crossattention.q_attn.weight', 'h.1.ln_cross_attn.bias', 'h.1.ln_cross_attn.weight', 'h.10.crossattention.c_attn.bias', 'h.10.crossattention.c_attn.weight', 'h.10.crossattention.c_proj.bias', 'h.10.crossattention.c_proj.wei

Epoch 1 Training Loss: 2.3149
Epoch 1 Validation Loss: 1.0593


Epoch 2/20: 100%|██████████| 18/18 [01:31<00:00,  5.11s/it, loss=1.3265]


Epoch 2 Training Loss: 1.2821
Epoch 2 Validation Loss: 0.7419


Epoch 3/20: 100%|██████████| 18/18 [01:30<00:00,  5.05s/it, loss=0.9980]


Epoch 3 Training Loss: 0.9197
Epoch 3 Validation Loss: 0.4890


Epoch 4/20: 100%|██████████| 18/18 [01:31<00:00,  5.10s/it, loss=0.7477]


Epoch 4 Training Loss: 0.7879
Epoch 4 Validation Loss: 0.4499


Epoch 5/20: 100%|██████████| 18/18 [01:31<00:00,  5.08s/it, loss=0.5928]


Epoch 5 Training Loss: 0.6756
Epoch 5 Validation Loss: 0.3632


Epoch 6/20: 100%|██████████| 18/18 [01:32<00:00,  5.15s/it, loss=0.6343]


Epoch 6 Training Loss: 0.5615
Epoch 6 Validation Loss: 0.3296


Epoch 7/20: 100%|██████████| 18/18 [01:30<00:00,  5.04s/it, loss=0.5875]


Epoch 7 Training Loss: 0.4910
Epoch 7 Validation Loss: 0.3130


Epoch 8/20: 100%|██████████| 18/18 [01:31<00:00,  5.10s/it, loss=0.5441]


Epoch 8 Training Loss: 0.4467
Epoch 8 Validation Loss: 0.2957


Epoch 9/20: 100%|██████████| 18/18 [01:32<00:00,  5.11s/it, loss=0.5347]


Epoch 9 Training Loss: 0.4639
Epoch 9 Validation Loss: 0.3035


Epoch 10/20: 100%|██████████| 18/18 [01:31<00:00,  5.09s/it, loss=0.3903]


Epoch 10 Training Loss: 0.3968
Epoch 10 Validation Loss: 0.2713


Epoch 11/20: 100%|██████████| 18/18 [01:32<00:00,  5.11s/it, loss=0.3418]


Epoch 11 Training Loss: 0.3535
Epoch 11 Validation Loss: 0.2619


Epoch 12/20: 100%|██████████| 18/18 [01:32<00:00,  5.15s/it, loss=0.3635]


Epoch 12 Training Loss: 0.3019
Epoch 12 Validation Loss: 0.2516


Epoch 13/20: 100%|██████████| 18/18 [01:31<00:00,  5.08s/it, loss=0.2650]


Epoch 13 Training Loss: 0.2691
Epoch 13 Validation Loss: 0.2288


Epoch 14/20: 100%|██████████| 18/18 [01:31<00:00,  5.09s/it, loss=0.2223]


Epoch 14 Training Loss: 0.2437
Epoch 14 Validation Loss: 0.2173


Epoch 15/20: 100%|██████████| 18/18 [01:31<00:00,  5.09s/it, loss=0.2419]


Epoch 15 Training Loss: 0.2302
Epoch 15 Validation Loss: 0.2132


Epoch 16/20: 100%|██████████| 18/18 [01:32<00:00,  5.12s/it, loss=0.2365]


Epoch 16 Training Loss: 0.2165
Epoch 16 Validation Loss: 0.2140


Epoch 17/20: 100%|██████████| 18/18 [01:30<00:00,  5.01s/it, loss=0.1750]


Epoch 17 Training Loss: 0.2025
Epoch 17 Validation Loss: 0.2137


Epoch 18/20: 100%|██████████| 18/18 [01:32<00:00,  5.15s/it, loss=0.3156]


Epoch 18 Training Loss: 0.2205
Epoch 18 Validation Loss: 0.2246


Epoch 19/20: 100%|██████████| 18/18 [01:30<00:00,  5.05s/it, loss=0.2361]


Epoch 19 Training Loss: 0.2138
Epoch 19 Validation Loss: 0.2169


Epoch 20/20: 100%|██████████| 18/18 [01:31<00:00,  5.10s/it, loss=0.1771]


Epoch 20 Training Loss: 0.2116
Epoch 20 Validation Loss: 0.2275
Training Complete!


A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.



=== Sample Generated Reports and Ground Truth Reports ===

[Ground Truth Report 1]:
[ Finding ] 
 
[ Diagnosis ] 
1. Two separate ossicles with sclerotic margin in right medial hallux sessamoid bone. 
 - Old fracture or bipartite sessamoid bone. 
 REC) Clinical correlation. 
2. Mild soft tissue swelling in medial portion of right foot. 
[ Recommend ]

[Generated Report 1]:
 Finding ] 
no bony lesion. [ Diagnosis - IID or XXL joint, left] no significant body abnormality on radiographs of rt 2nd MT shaft since last study in 2018-11/14 . ---> possible RA involvement with multiple inflammatory arthritis such as

--------------------------------------------------
[Ground Truth Report 2]:
[FINDING       ] joint space narrowing, possible erosions, Rt TMT joints 
-> r/o RA involvement 

degenerative change, talonavicular joint, Rt 

flat foot, Rt   [CONCLUSION    ] joint space narrowing, possible erosions, Rt TMT joints 
-> r/o RA involvement 

degenerative change, talonavicular joint, Rt 

f

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Average BERTScore F1: 0.8320
