In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.models as models
from sklearn.metrics import roc_auc_score, roc_curve
from tqdm import tqdm
import logging
import time
import matplotlib.pyplot as plt
import numpy as np
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import DataLoader
import yaml
from data import DataLoader as CustomDataLoader

torch.cuda.empty_cache()

# Initialize logging
logging.basicConfig(filename='training.log', level=logging.INFO, 
                    format='%(asctime)s - %(levelname)s - %(message)s')

# Device configuration
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

# Load configuration
config_file = "config1.yaml"
with open(config_file, 'r') as f:
    config = yaml.safe_load(f)
config['data_pct'] = 100

# Data loading
data_ins = CustomDataLoader(config)
train_loader, valid_loader, test_loader = data_ins.GetMimicDataset()

# Define the ResNet18-based model with BYOL
class BYOL(nn.Module):
    def __init__(self, base_encoder, hidden_dim=4096, projection_dim=256, num_classes=15, moving_average_decay=0.99):
        super(BYOL, self).__init__()
        self.base_encoder = base_encoder
        self.projection_dim = projection_dim
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes

        # Determine the output size from base_encoder
        with torch.no_grad():
            dummy_input = torch.zeros(1, 3, 224, 224).to(device)
            output_size = self.base_encoder(dummy_input).view(1, -1).size(1)

        self.online_encoder = nn.Sequential(
            self.base_encoder,
            nn.Linear(output_size, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, projection_dim)
        )

        self.target_encoder = nn.Sequential(
            self.base_encoder,
            nn.Linear(output_size, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, projection_dim)
        )

        for param_online, param_target in zip(self.online_encoder.parameters(), self.target_encoder.parameters()):
            param_target.data.copy_(param_online.data)
            param_target.requires_grad = False

        self.moving_average_decay = moving_average_decay

        self.classifier = nn.Sequential(
            nn.Linear(output_size, num_classes),
            nn.Sigmoid()
        )

    def forward(self, x1, x2=None):
        if x2 is None:
            return self.classifier(self.base_encoder(x1))

        online_proj_one = self.online_encoder(x1)
        online_proj_two = self.online_encoder(x2)
        target_proj_one = self.target_encoder(x1).detach()
        target_proj_two = self.target_encoder(x2).detach()
        return online_proj_one, online_proj_two, target_proj_one, target_proj_two

    def update_target_network(self):
        for param_online, param_target in zip(self.online_encoder.parameters(), self.target_encoder.parameters()):
            param_target.data = self.moving_average_decay * param_target.data + (1 - self.moving_average_decay) * param_online.data

# Define the BYOL loss function
def byol_loss(p1, p2, z1, z2):
    loss_one = 2 - 2 * (p1 * z2.detach()).sum(dim=-1)
    loss_two = 2 - 2 * (p2 * z1.detach()).sum(dim=-1)
    return (loss_one + loss_two).mean()

# Define transformations for BYOL
byol_transforms = T.Compose([
    T.RandomHorizontalFlip(p=0.5),
    T.RandomVerticalFlip(p=0.5),
    T.RandomRotation(30),
    T.RandomResizedCrop(224, scale=(0.8, 1.0)),
    T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Model initialization
num_classes = 15
base_encoder = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1).to(device)
byol_model = BYOL(base_encoder, hidden_dim=4096, projection_dim=256, num_classes=num_classes).to(device)

# Initialize BioBERT
tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-base-cased-v1.1")
biobert_model = AutoModel.from_pretrained("dmis-lab/biobert-base-cased-v1.1").to(device)

# Define combined model
class CombinedModel(nn.Module):
    def __init__(self, image_model, text_model, image_feature_dim, text_feature_dim, hidden_dim=512, num_classes=15):
        super(CombinedModel, self).__init__()
        self.image_model = image_model
        self.text_model = text_model
        self.fc_image = nn.Linear(image_feature_dim, hidden_dim)
        self.fc_text = nn.Linear(text_feature_dim, hidden_dim)
        self.classifier = nn.Linear(hidden_dim, num_classes)
    
    def forward(self, images, input_ids, attention_mask):
        image_features = self.image_model(images)
        text_outputs = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
        text_features = text_outputs.last_hidden_state[:, 0, :]
        
        combined_features = F.relu(self.fc_image(image_features)) + F.relu(self.fc_text(text_features))
        output = torch.sigmoid(self.classifier(combined_features))
        return output

# Instantiate combined model
image_feature_dim = 4096  # Change based on your BYOL output dimension
text_feature_dim = 768  # BioBERT output dimension
combined_model = CombinedModel(byol_model, biobert_model, image_feature_dim, text_feature_dim).to(device)

# Training and validation setup
num_epochs = 10
learning_rate = 0.001
optimizer = torch.optim.Adam(combined_model.parameters(), lr=learning_rate)
classification_criterion = nn.BCELoss()

# Training loop for the combined model
total_start_time = time.time()
roc_auc_scores = []

for epoch in range(num_epochs):
    combined_model.train()
    epoch_loss = 0
    for batch in tqdm(train_loader):
        if len(batch) == 3:
            images, text, labels = batch
            input_ids = text['input_ids']
            attention_mask = text['attention_mask']
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
        elif len(batch) == 2:
            images, labels = batch
            input_ids = None
            attention_mask = None
        else:
            raise ValueError(f"Unexpected batch structure: {len(batch)} elements")

        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        outputs = combined_model(images, input_ids, attention_mask)
        loss = classification_criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    logging.info(f"Epoch [{epoch+1}/{num_epochs}], Classification Loss: {epoch_loss/len(train_loader):.4f}")

    # Validation
    combined_model.eval()
    all_labels = []
    all_preds = []
    with torch.no_grad():
        for batch in tqdm(valid_loader):
            if len(batch) == 3:
                images, text, labels = batch
                input_ids = text['input_ids']
                attention_mask = text['attention_mask']
                input_ids = input_ids.to(device)
                attention_mask = attention_mask.to(device)
            elif len(batch) == 2:
                images, labels = batch
                input_ids = None
                attention_mask = None
            else:
                raise ValueError(f"Unexpected batch structure: {len(batch)} elements")

            images = images.to(device)
            labels = labels.to(device)

            outputs = combined_model(images, input_ids, attention_mask)
            all_labels.append(labels.cpu().numpy())
            all_preds.append(outputs.cpu().numpy())

    all_labels = np.concatenate(all_labels)
    all_preds = np.concatenate(all_preds)
    roc_auc = roc_auc_score(all_labels, all_preds, average=None)
    roc_auc_scores.append(roc_auc)

    logging.info(f"Epoch [{epoch+1}/{num_epochs}], Validation ROC AUC: {roc_auc}")

total_end_time = time.time()
total_duration = total_end_time - total_start_time
logging.info(f"Total Training Time: {total_duration:.2f} seconds")

# Save the trained model
torch.save(combined_model.state_dict(), "combined_model.pth")

# Plot ROC AUC scores
plt.figure(figsize=(10, 8))
for i in range(num_classes):
    plt.plot([roc_auc[i] for roc_auc in roc_auc_scores], label=f'Class {i}')
plt.xlabel('Epoch')
plt.ylabel('ROC AUC')
plt.title('ROC AUC Scores per Epoch')
plt.legend()
plt.grid(True)
plt.show()

213357 images have loaded for training
4774 images have loaded for validation
4774 images have loaded for testing


  0%|                                                                                                                           | 0/3333 [00:05<?, ?it/s]


AttributeError: 'list' object has no attribute 'to'