In [4]:
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.models as models
from sklearn.metrics import roc_auc_score, roc_curve
from tqdm import tqdm
import logging
import time
import matplotlib.pyplot as plt
import numpy as np
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import DataLoader
import yaml
from data import DataLoader as CustomDataLoader

torch.cuda.empty_cache()

# Initialize logging: INFO-and-above records are written to training.log.
logging.basicConfig(
    filename='training.log',
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)

# Device configuration.
# Prefer the second GPU as before, but only when it actually exists: asking
# for "cuda:1" on a single-GPU host fails with an invalid device ordinal.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Load configuration
config_file = "config1.yaml"
with open(config_file, 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)
# Override whatever the YAML file says for this key.
# NOTE(review): presumably the percentage of the dataset to use -- confirm
# against the project's data loader.
config['data_pct'] = 100

# Data loading: the project loader returns the train/valid/test loaders.
data_ins = CustomDataLoader(config)
train_loader, valid_loader, test_loader = data_ins.GetMimicDataset()

# Define custom BYOL model
class ProjectionHead(nn.Module):
    """MLP head: Linear -> BatchNorm1d -> ReLU -> Linear.

    Args:
        in_dim: Number of input features of the first linear layer.
        hidden_dim: Width of the hidden layer (and of the batch norm).
        out_dim: Number of output features of the last linear layer.
    """

    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        # Layers are created in application order and stored under ``block``
        # so that parameter names (``block.0.weight`` ...) stay the same.
        layers = [
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to ``x`` and return the result."""
        projected = self.block(x)
        return projected

class PredictionHead(nn.Module):
    """MLP head: Linear -> BatchNorm1d -> ReLU -> Linear.

    Args:
        in_dim: Number of input features of the first linear layer.
        hidden_dim: Width of the hidden layer (and of the batch norm).
        out_dim: Number of output features of the last linear layer.
    """

    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        # Built as a list first; the attribute name ``block`` is kept so the
        # parameter names in the state dict do not change.
        stages = [
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        ]
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the MLP to ``x`` and return the result."""
        predicted = self.block(x)
        return predicted

class BYOL(nn.Module):
    """Online/target network pair in the BYOL layout.

    The online path is ``backbone -> projection_head -> prediction_head``.
    The target ("momentum") path is a deep copy of the backbone and of the
    projection head, taken at construction time, with gradients disabled.
    This class does not update the momentum copies itself; whoever trains the
    model is responsible for that.

    Args:
        backbone: Feature extractor. Its output, flattened from dim 1 on,
            must have ``feature_dim`` columns.
        feature_dim: Width of the flattened backbone output. Defaults to
            2048, the value previously hard-coded here.
        hidden_dim: Hidden width of both MLP heads (default 4096).
        out_dim: Output width of the projection and prediction heads
            (default 256).
    """

    def __init__(self, backbone, feature_dim=2048, hidden_dim=4096, out_dim=256):
        super().__init__()
        self.backbone = backbone
        self.projection_head = ProjectionHead(feature_dim, hidden_dim, out_dim)
        self.prediction_head = PredictionHead(out_dim, hidden_dim, out_dim)

        # Target networks start as exact copies of the online ones.
        self.backbone_momentum = copy.deepcopy(self.backbone)
        self.projection_head_momentum = copy.deepcopy(self.projection_head)

        # The target side is never trained by back-propagation.
        for target_module in (self.backbone_momentum, self.projection_head_momentum):
            for param in target_module.parameters():
                param.requires_grad = False

    def forward(self, x):
        """Return the online prediction for ``x``."""
        y = self.backbone(x).flatten(start_dim=1)
        z = self.projection_head(y)
        p = self.prediction_head(z)
        return p

    def forward_momentum(self, x):
        """Return the detached target projection for ``x``."""
        y = self.backbone_momentum(x).flatten(start_dim=1)
        z = self.projection_head_momentum(y)
        return z.detach()

def negative_cosine_similarity(p, z):
    """Return the negated mean cosine similarity between ``p`` and ``z``.

    The similarity is taken along the last dimension. ``z`` is detached
    first, so gradients flow only through ``p``.
    """
    target = z.detach()
    similarity = F.cosine_similarity(p, target, dim=-1)
    return -similarity.mean()

def vicreg_loss(x, y, sim_weight=25.0, var_weight=25.0, cov_weight=1.0):
    """Weighted sum of an invariance, a variance and a covariance term.

    Args:
        x, y: 2-D tensors; statistics are taken over dim 0.
        sim_weight: Weight of the mean-squared error between ``x`` and ``y``.
        var_weight: Weight of the hinge ``relu(1 - std)`` averaged over the
            columns, summed over both inputs.
        cov_weight: Weight of the summed squared off-diagonal covariance
            entries, summed over both inputs.

    Returns:
        A scalar tensor: ``sim_weight * mse + variance term + covariance term``.
    """
    invariance = F.mse_loss(x, y)

    def _branch_terms(features):
        # Centre each column, then measure its spread and its correlation
        # with the other columns.
        centered = features - features.mean(dim=0)
        spread = torch.sqrt(centered.var(dim=0) + 1e-4)
        hinge = torch.mean(F.relu(1 - spread))
        covariance = (centered.T @ centered) / (centered.size(0) - 1)
        redundancy = off_diagonal(covariance).pow_(2).sum()
        return hinge, redundancy

    hinge_x, redundancy_x = _branch_terms(x)
    hinge_y, redundancy_y = _branch_terms(y)

    variance_term = (hinge_x + hinge_y) * var_weight
    covariance_term = (redundancy_x + redundancy_y) * cov_weight

    return sim_weight * invariance + variance_term + covariance_term

def off_diagonal(x):
    """Return the off-diagonal entries of a square matrix as a 1-D tensor.

    Entries come out in row-major order. A non-square input trips the
    assertion.
    """
    rows, cols = x.shape
    assert rows == cols
    # Dropping the last element and re-viewing as (n - 1, n + 1) lines every
    # remaining diagonal entry up in column 0, which is then sliced away.
    shifted = x.flatten()[:-1].view(rows - 1, rows + 1)
    return shifted[:, 1:].flatten()

# Image branch: ImageNet-pretrained ResNet-50 with its last child module
# dropped, wrapped in the BYOL model.
resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
resnet = resnet.to(device)
backbone = nn.Sequential(*tuple(resnet.children())[:-1])
byol_model = BYOL(backbone)
byol_model = byol_model.to(device)

# Text branch: BioBERT tokenizer and encoder, both loaded from the same
# checkpoint name.
tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-base-cased-v1.1")
biobert_model = AutoModel.from_pretrained("dmis-lab/biobert-base-cased-v1.1")
biobert_model = biobert_model.to(device)

# Define combined model
class CombinedModel(nn.Module):
    """Multi-label classifier fusing an image encoder and a text encoder.

    Each modality is mapped to ``hidden_dim`` by its own linear layer followed
    by ReLU; the two results are added and fed to a linear classifier whose
    outputs go through a sigmoid.

    Args:
        image_model: Callable mapping an image batch to features with
            ``image_feature_dim`` columns.
        text_model: Callable accepting ``input_ids`` and ``attention_mask``
            keyword arguments and returning an object with a
            ``last_hidden_state`` attribute whose last dim is
            ``text_feature_dim``.
        image_feature_dim: Width of the image features.
        text_feature_dim: Width of the text features.
        hidden_dim: Width of the shared fusion space.
        num_classes: Number of sigmoid outputs.
    """

    def __init__(self, image_model, text_model, image_feature_dim, text_feature_dim, hidden_dim=512, num_classes=15):
        super().__init__()
        self.image_model = image_model
        self.text_model = text_model
        self.fc_image = nn.Linear(image_feature_dim, hidden_dim)
        self.fc_text = nn.Linear(text_feature_dim, hidden_dim)
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, images, input_ids=None, attention_mask=None):
        """Return per-class probabilities for a batch.

        Args:
            images: Image batch passed to ``image_model``.
            input_ids: Token ids for ``text_model``, or ``None`` for an
                image-only batch.
            attention_mask: Attention mask for ``text_model``; ignored when
                ``input_ids`` is ``None``.

        Returns:
            Tensor with ``num_classes`` columns holding sigmoid outputs.
        """
        image_features = self.image_model(images)
        combined_features = F.relu(self.fc_image(image_features))

        # Image-only batches carry no text; calling the text encoder with
        # ``None`` would fail, so the text branch is skipped for them.
        if input_ids is not None:
            text_outputs = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
            # Use the hidden state at sequence position 0 as the text summary.
            text_features = text_outputs.last_hidden_state[:, 0, :]
            combined_features = combined_features + F.relu(self.fc_text(text_features))

        output = torch.sigmoid(self.classifier(combined_features))
        return output

# Training hyperparameters
num_epochs = 10
learning_rate = 0.001

# Feature widths handed to the fusion layers
text_feature_dim = 768  # BioBERT output dimension
image_feature_dim = 256  # Adjust based on your BYOL output dimension

# Combined image + text classifier, moved to the selected device
combined_model = CombinedModel(
    image_model=byol_model,
    text_model=biobert_model,
    image_feature_dim=image_feature_dim,
    text_feature_dim=text_feature_dim,
)
combined_model = combined_model.to(device)

# Binary cross-entropy on the model outputs, optimised with Adam over every
# parameter of the combined model
classification_criterion = nn.BCELoss()
optimizer = torch.optim.Adam(params=combined_model.parameters(), lr=learning_rate)

# Training loop for the combined model


def _unpack_batch(batch):
    """Split a loader batch into ``(images, input_ids, attention_mask, labels)``.

    Three-element batches are ``(images, text, labels)`` where ``text`` is
    indexed with ``'input_ids'`` and ``'attention_mask'``; two-element batches
    are ``(images, labels)`` and yield ``None`` for both text tensors.
    All returned tensors are moved to ``device``.

    Raises:
        ValueError: If the batch has neither two nor three elements.
    """
    if len(batch) == 3:
        images, text, labels = batch
        input_ids = text['input_ids'].to(device)
        attention_mask = text['attention_mask'].to(device)
    elif len(batch) == 2:
        images, labels = batch
        input_ids = None
        attention_mask = None
    else:
        raise ValueError(f"Unexpected batch structure: {len(batch)} elements")

    # Ensure images are a single tensor of the expected rank.
    if isinstance(images, list):
        images = torch.stack(images)

    # Drop singleton dims; if that also removed the batch dim (3-D result),
    # put it back.
    images = images.squeeze()
    if len(images.shape) == 3:
        images = images.unsqueeze(0)

    return images.to(device), input_ids, attention_mask, labels.to(device)


def _ema_update(online, target, momentum):
    """Move ``target`` parameters toward ``online``: t = m * t + (1 - m) * o."""
    with torch.no_grad():
        for online_param, target_param in zip(online.parameters(), target.parameters()):
            target_param.mul_(momentum).add_(online_param.detach(), alpha=1.0 - momentum)


# EMA coefficient for the BYOL target networks.
ema_momentum = 0.99

total_start_time = time.time()
roc_auc_scores = []

for epoch in range(num_epochs):
    combined_model.train()
    epoch_loss = 0.0
    for batch in tqdm(train_loader):
        images, input_ids, attention_mask, labels = _unpack_batch(batch)

        optimizer.zero_grad()

        # Forward pass
        outputs = combined_model(images, input_ids, attention_mask)

        # Classification loss; BCELoss needs floating-point targets, and the
        # cast is a no-op when the labels already are.
        classification_loss = classification_criterion(outputs, labels.float())

        # BYOL loss between the online prediction and the target projection
        p1 = byol_model(images)
        z1 = byol_model.forward_momentum(images)
        loss_byol = negative_cosine_similarity(p1, z1)

        # VICReg regulariser on the image side, matched against the same
        # quantity computed on the text features.
        variance_I = vicreg_loss(p1, z1)
        if input_ids is not None:
            # The text features used inside combined_model are local to its
            # forward(), so they are recomputed here with a second pass
            # through the text encoder.
            text_outputs = biobert_model(input_ids=input_ids, attention_mask=attention_mask)
            text_features = text_outputs.last_hidden_state[:, 0, :]
            variance_T = vicreg_loss(text_features, text_features)
            loss_vicreg = F.mse_loss(variance_I, variance_T)
        else:
            # No text in this batch: nothing to match against.
            loss_vicreg = variance_I.new_zeros(())

        # Combined loss
        loss = (classification_loss + loss_byol + loss_vicreg) / 3
        loss.backward()
        optimizer.step()

        # Let the target networks track the online ones; without this the
        # targets stay at their initial weights for the whole run.
        _ema_update(byol_model.backbone, byol_model.backbone_momentum, ema_momentum)
        _ema_update(byol_model.projection_head, byol_model.projection_head_momentum, ema_momentum)

        epoch_loss += loss.item()

    # The accumulated value is the combined loss, not only the classification term.
    logging.info(f"Epoch [{epoch+1}/{num_epochs}], Combined Loss: {epoch_loss/len(train_loader):.4f}")

    # Validation
    combined_model.eval()
    all_labels = []
    all_preds = []
    with torch.no_grad():
        for batch in tqdm(valid_loader):
            images, input_ids, attention_mask, labels = _unpack_batch(batch)

            outputs = combined_model(images, input_ids, attention_mask)
            all_labels.append(labels.cpu().numpy())
            all_preds.append(outputs.cpu().numpy())

    all_labels = np.concatenate(all_labels)
    all_preds = np.concatenate(all_preds)
    # average=None keeps one score per class instead of a single mean.
    roc_auc = roc_auc_score(all_labels, all_preds, average=None)
    roc_auc_scores.append(roc_auc)

    logging.info(f"Epoch [{epoch+1}/{num_epochs}], Validation ROC AUC: {roc_auc}")

total_end_time = time.time()
total_duration = total_end_time - total_start_time
logging.info(f"Total Training Time: {total_duration:.2f} seconds")

# Plot ROC AUC scores
# Transposing the per-epoch list gives one row per class, so each class is
# drawn as its own curve across epochs.
per_class_history = np.transpose(np.array(roc_auc_scores))

plt.figure()
for class_index, class_history in enumerate(per_class_history):
    plt.plot(class_history, label=f'Class {class_index}')
plt.title('ROC AUC Score per Class')
plt.xlabel('Epoch')
plt.ylabel('ROC AUC Score')
plt.legend()
plt.savefig('roc_auc_scores.png')

Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3553, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_234433/173065181.py", line 16, in <module>
    from data import DataLoader as CustomDataLoader
  File "/workspace/hardik/data/__init__.py", line 1, in <module>
    from .data_loader import DataLoader
  File "/workspace/hardik/data/data_loader.py", line 28, in <module>
    from albumentations import Compose, Normalize, Resize, ShiftScaleRotate
  File "/opt/conda/lib/python3.10/site-packages/albumentations/__init__.py", line 6, in <module>
    from .augmentations import *
  File "/opt/conda/lib/python3.10/site-packages/albumentations/augmentations/__init__.py", line 2, in <module>
    from .blur.transforms import *
  File "/opt/conda/lib/python3.10/site-packages/albumentations/augmentations/blur/__init__.py", line 2, in <module>
    from .transforms import *
  File "/opt/co