In [1]:
import os
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.models as models
from sklearn.metrics import roc_auc_score, roc_curve
from tqdm import tqdm
import logging
import time
import matplotlib.pyplot as plt
import numpy as np
import torch.backends.cudnn as cudnn

from transformers import AutoTokenizer, AutoModel
from torch.utils.data import DataLoader
import yaml
from optimizer import LARS
from utils import log, AverageMeter, collect_params

# from data import DataLoader as CustomDataLoader
from data import DataLoader
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
from byol import BYOL

In [3]:
# # Data loading
# data_ins   = DataLoader(config)
# # train_loader, valid_loader, test_loader = data_ins.GetMimicDataset() #you are not supposed to use the mimic dataset
# #you need to use the multimodal data
# train_loader, valid_loader = data_ins.GetMultimodalPretrainingDataset()

In [4]:
# # Define custom BYOL model
# class ProjectionHead(nn.Module):
#     def __init__(self, in_dim, hidden_dim, out_dim):
#         super(ProjectionHead, self).__init__()
#         self.block = nn.Sequential(
#             nn.Linear(in_dim, hidden_dim),
#             nn.BatchNorm1d(hidden_dim),
#             nn.ReLU(),
#             nn.Linear(hidden_dim, out_dim)
#         )

#     def forward(self, x):
#         return self.block(x)

# class PredictionHead(nn.Module):
#     def __init__(self, in_dim, hidden_dim, out_dim):
#         super(PredictionHead, self).__init__()
#         self.block = nn.Sequential(
#             nn.Linear(in_dim, hidden_dim),
#             nn.BatchNorm1d(hidden_dim),
#             nn.ReLU(),
#             nn.Linear(hidden_dim, out_dim)
#         )

#     def forward(self, x):
#         return self.block(x)

# class BYOL(nn.Module):
#     def __init__(self, backbone):
#         super(BYOL, self).__init__()
#         self.backbone = backbone
#         self.projection_head = ProjectionHead(2048, 4096, 256)
#         self.prediction_head = PredictionHead(256, 4096, 256)

#         self.backbone_momentum = copy.deepcopy(self.backbone)
#         self.projection_head_momentum = copy.deepcopy(self.projection_head)

#         for param in self.backbone_momentum.parameters():
#             param.requires_grad = False
#         for param in self.projection_head_momentum.parameters():
#             param.requires_grad = False

#     def forward_online(self, x):
#         y = self.backbone(x).flatten(start_dim=1)
#         z = self.projection_head(y)
#         p = self.prediction_head(z)
#         return p

# #     def forward_momentum(self, x):
#     def forward_target(self, x):
#         y = self.backbone_momentum(x).flatten(start_dim=1)
#         z = self.projection_head_momentum(y)
#         z = z.detach()
#         return z
    
#     def forward(self, x):
#         online = self.forward_online(x)
#         target = self.forward_target(x)
        
#         return online, target
    

# def negative_cosine_similarity(p, z):
#     return -F.cosine_similarity(p, z.detach(), dim=-1).mean()

# def vicreg_loss(x, y, sim_weight=25.0, var_weight=25.0, cov_weight=1.0):
# #     repr_loss = F.mse_loss(x, y)

#     x = x - x.mean(dim=0)
#     y = y - y.mean(dim=0)
    
#     std_x = torch.sqrt(x.var(dim=0) + 1e-4)
#     std_y = torch.sqrt(y.var(dim=0) + 1e-4)
#     std_loss = (torch.mean(F.relu(1 - std_x)) + torch.mean(F.relu(1 - std_y))) * var_weight
    
#     cov_x = (x.T @ x) / (x.size(0) - 1)
#     cov_y = (y.T @ y) / (y.size(0) - 1)
#     cov_loss = (off_diagonal(cov_x).pow_(2).sum() + off_diagonal(cov_y).pow_(2).sum()) * cov_weight
    
#     return std_loss + cov_loss

# def off_diagonal(x):
#     n, m = x.shape
#     assert n == m
#     return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()

In [5]:
# # BYOL
# resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1).to(device)
# backbone = nn.Sequential(*list(resnet.children())[:-1]).to(device) ##added .to(device)
# byol_model = BYOL(backbone).to(device)

In [6]:
# class TextProjectionHead(nn.Module):
#     def __init__(self):
#         super().__init__()
#         embedding_dim= 768
#         projection_dim=256
#         dropout=0.2
        
#         self.projection = nn.Linear(embedding_dim, projection_dim)
#         self.gelu       = nn.GELU()
#         self.fc         = nn.Linear(projection_dim, projection_dim)
#         self.dropout    = nn.Dropout(dropout)
#         self.layer_norm = nn.LayerNorm(projection_dim)
    
#     def forward(self, x):
#         projected = self.projection(x)
#         x = self.gelu(projected)
#         x = self.fc(x)
#         x = self.dropout(x)
#         x = x + projected
#         x = self.layer_norm(x)
#         return x

In [7]:
class TextEncoder(nn.Module):
    def __init__(self ):
        super().__init__()

        model_name="emilyalsentzer/Bio_ClinicalBERT"
        pretrained=True
        trainable=False
        
        if pretrained:
            self.model = AutoModel.from_pretrained(model_name)
        for p in self.model.parameters():
            p.requires_grad = trainable

        # we are using the CLS token hidden representation as the sentence's embedding
        self.target_token_idx = 0

    def forward(self, input_ids, attention_mask):
        _, pooled_output = self.model(input_ids=input_ids, attention_mask=attention_mask,return_dict=False)
        return pooled_output

In [8]:
# class CombinedModel(nn.Module):
#     def __init__(self, image_model, text_model):
#         super(CombinedModel, self).__init__()
#         self.image_model = image_model
#         self.text_model = text_model
#         # self.fc = nn.Linear(256, 1)##changed
    
#     def forward(self, images, input_ids, attention_mask):
#         online, target = self.image_model(images)
#         text_features = self.text_model(input_ids,attention_mask)
#         # outputs = self.fc(text_features) ##changed
#         return online, target, text_features

In [9]:
# biobert_model = TextEncoder()
# combined_model = CombinedModel(byol_model, biobert_model).to(device)

In [10]:
# # Training and validation
# num_epochs = 10
# learning_rate = 0.001
# optimizer = torch.optim.Adam(combined_model.parameters(), lr=learning_rate)
# classification_criterion = nn.BCELoss()

# # Training loop for the combined model
# total_start_time = time.time()
# roc_auc_scores = []

In [11]:
# torch.cuda.empty_cache()

# Initialize logging
# logging.basicConfig(filename='training.log', level=logging.INFO, 
#                     format='%(asctime)s - %(levelname)s - %(message)s')

# Device configuration
# device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")



In [12]:
# Training loop for the combined model
# num_epochs = 300
class Trainer():
    def __init__(self,config):
        self.config = config
        self.total_epochs  = config['optimizer']['total_epochs']
        self.warmup_epochs = config['optimizer']['warmup_epochs']
        self.batch_size = config['pre_bs']
        
        self.data_ins = DataLoader(config)
        self.train_loader, self.valid_loader = self.data_ins.GetMultimodalPretrainingDataset()        
        
        num_examples = len(self.train_loader)*self.batch_size
        self.warmup_steps  = self.warmup_epochs * num_examples//self.batch_size    
        self.total_steps   = self.total_epochs * num_examples //self.batch_size

        self.base_lr   = self.config['optimizer']['base_lr']/256
        self.max_lr    = self.base_lr * self.batch_size
        
        self.base_mm   = self.config['model']['base_momentum']
        self.gpu       = self.config['gpu']
        
        self.resume_path = self.config['checkpoint']['resume_path']
        if torch.cuda.is_available():
            self.device = torch.device(f'cuda:{self.gpu}')
            torch.cuda.set_device(self.device)
            cudnn.benchmark = True
        else:
            self.device = torch.device('cpu')
            
        torch.cuda.empty_cache()
        self.model_name = self.config['model']['model_name']
        self.dataset = self.config['dataset']
        
        save_path   = os.path.join('./ckpt',self.model_name.lower())
        self.method_name = f"{config['model']['backbone']['type']}_{self.dataset}_{self.batch_size}_{self.total_epochs}"
        self.config['checkpoint']['ckpt_path'] = os.path.join(save_path,self.method_name)        
        os.makedirs(config['checkpoint']['ckpt_path'], exist_ok=True)
        self.logger = log(path=config['checkpoint']['ckpt_path'], file=f"{self.method_name}.logs")
        
        """log tools in the running phase"""
        self.steps = 0
        self.total_training_time = 0
        self.log_step   = self.config['checkpoint']['log_step']
        self.save_epoch = self.config['checkpoint']['save_epoch']
        self.construct_model()
        
    def construct_model(self):
        self.logger.info("init model!")
                
        byol_model = BYOL(self.config)
        self.image_model = byol_model.to(self.device)
        self.logger.info(self.image_model)
        
        self.text_encoder     = TextEncoder().to(self.device)
#         self.text_projection  = TextProjectionHead().to(self.device)        
        self.logger.info(self.text_encoder)
#         self.logger.info(self.text_projection)


        self.logger.info("get optimizer!")
        momentum = self.config['optimizer']['momentum']
        weight_decay = self.config['optimizer']['weight_decay']
        exclude_bias_and_bn = self.config['optimizer']['exclude_bias_and_bn']
        params = collect_params([self.image_model.online_network, self.image_model.predictor],exclude_bias_and_bn=exclude_bias_and_bn)
        self.optimizer = LARS(params, lr=self.max_lr,momentum=momentum, weight_decay=weight_decay)
        
    def resume_model(self, model_path=None):
        if model_path is None and not self.resume_path:
            self.start_epoch = 0
            self.logger.info("--> No loaded checkpoint!")
        else:
            model_path = model_path or self.resume_path
            checkpoint = torch.load(model_path, map_location=self.device)

            self.start_epoch = checkpoint['epoch']
            self.steps = checkpoint['steps']
            self.model.load_state_dict(checkpoint['model'], strict=True)
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.logger.info(f"--> Loaded checkpoint '{model_path}' (epoch {self.start_epoch})")

    # save snapshots
    def save_checkpoint(self, epoch):
        if epoch % self.save_epoch == 0:
            model_state = {'config': self.config,
                           'epoch': epoch,
                           'steps': self.steps,
                           'model': self.image_model.state_dict(),
                           'online': self.image_model.online_network.state_dict(),
                           'optimizer': self.optimizer.state_dict(),
                     }
            online_state = {'online': self.image_model.online_network.state_dict()}
            SAVE_PATH1 = os.path.join(self.config['checkpoint']['ckpt_path'], f'{self.method_name}.pth')
            SAVE_PATH2 = os.path.join(self.config['checkpoint']['ckpt_path'], f'{self.method_name}_{epoch}.pth')
            torch.save(model_state, SAVE_PATH1)
            torch.save(online_state, SAVE_PATH2)
            
    def adjust_learning_rate(self, step):
        """learning rate warm up and decay"""
        max_lr = self.max_lr
        min_lr = 1e-3 * self.max_lr
        if step < self.warmup_steps:
            lr = (max_lr - min_lr) * step / self.warmup_steps + min_lr
        else:
            lr = min_lr + 0.5 * \
                (max_lr - min_lr) * (1 + np.cos((step - self.warmup_steps) * np.pi / self.total_steps))
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def adjust_mm(self, step):
        self.mm = 1 - (1 - self.base_mm) * \
            (np.cos(np.pi * step / self.total_steps) + 1) / 2
    
    def regression_loss(self, preds, targets):
        bz = preds.size(0)
        preds_norm = F.normalize(preds, dim=1)
        targets_norm = F.normalize(targets, dim=1)
        loss = 2 - 2 * (preds_norm * targets_norm).sum() / bz
        return loss
    
#     def compute_variance(self,x):
#         x = x - x.mean(dim=0)
# #         y = y - y.mean(dim=0)

#         std_x = torch.sqrt(x.var(dim=0) + 1e-4)
# #         std_y = torch.sqrt(y.var(dim=0) + 1e-4)
# #         std_loss = (torch.mean(F.relu(1 - std_x)) + torch.mean(F.relu(1 - std_y))) * var_weight
#         return std_x 

    
    def recursive_to_device(self, inp, device):
        if isinstance(inp, list):
            return [self.recursive_to_device(item, device) for item in inp]
        elif isinstance(inp, torch.Tensor):
            return inp.to(device)
        else:
            return inp
        
    def train_epoch(self, epoch):
        loss_meter    = AverageMeter()
        var_loss_meter  = AverageMeter()
        loss_byol_meter = AverageMeter()
            
        self.image_model.train()
#         self.text_projection.train()
        
        epoch_start_time = time.time() 
        
        for idx, batch in enumerate(self.train_loader):
            input_ids = batch['caption_ids'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)
            imgs = batch['imgs']
            imgs = self.recursive_to_device(imgs, self.device)

            self.adjust_mm(self.steps)
            self.adjust_learning_rate(self.steps)
            self.steps += 1
            
            q1,q2, target_z1,target_z2 = self.image_model(imgs, self.mm)
            
            text_features   = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
#             text_embeddings = self.text_projection(text_features)
            
            loss_byol = self.regression_loss(q1,target_z2)
            loss_byol += self.regression_loss(q2,target_z1)
            
            variance_I = torch.var(q1, dim=0)
            variance_I += torch.var(q2, dim=0)
            
            variance_T = torch.var(text_features, dim=0)            
            var_loss = F.mse_loss(variance_I, variance_T)
            
            loss = loss_byol + 10*var_loss

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            loss_meter.update(loss.item(), imgs[0].size(0))
            var_loss_meter.update(var_loss.item(), imgs[0].size(0))
            loss_byol_meter.update(loss_byol.item(), imgs[0].size(0))

            # Print log info
            if self.steps % self.log_step == 0:
                lr = self.optimizer.param_groups[0]["lr"]
                mm = self.mm
                self.logger.info(f'Epoch: [{epoch}][{idx}/{len(self.train_loader)}]\t'
                                f'Step {self.steps}\t'
                                f'lr {round(self.optimizer.param_groups[0]["lr"], 5)}\t'
                                f'mm {round(self.mm, 5)}\t'
                                f'Loss {loss_meter.val:.4f} \t'
                                f'B_Loss {loss_byol_meter.val:.4f} \t'
                                f'V_Loss {var_loss_meter.val:.4f} \t'
                                )
                
        epoch_end_time = time.time()  # End time of current epoch
        epoch_training_time = (epoch_end_time - epoch_start_time)/60
        self.total_training_time += epoch_training_time
        self.logger.info(f"Epoch {epoch} training time: {epoch_training_time:.2f} minutes")
        if epoch == self.total_epochs +1:
            self.total_training_time_hours = self.total_training_time / 3600  
            self.logger.info(f"Total training time: {self.total_training_time_hours:.2f} hours")      

In [13]:
def run_task(config):
    trainer = Trainer(config)
    trainer.resume_model(model_path=None)
    start_epoch = trainer.start_epoch
    for epoch in range(start_epoch + 1, trainer.total_epochs + 1):
        trainer.train_epoch(epoch)
        trainer.save_checkpoint(epoch)

In [14]:
def main():
    # Load configuration
    config_file = "config1.yaml"
    with open(config_file, 'r') as f:
        config = yaml.safe_load(f)
    config['data_pct'] = 100    
    run_task(config)

if __name__ == "__main__":
    main()

INFO:root:init model!


227327 images have loaded for training
4959 images have loaded for validation


INFO:root:BYOL(
  (online_network): EncoderwithProjection(
    (encoder): Sequential(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): R

Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
INFO:root:TextEncoder(
  (model): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)


INFO:root:get optimizer!
INFO:root:--> No loaded checkpoint!
INFO:root:Epoch: [1][49/3551]	Step 50	lr 1e-05	mm 0.996	Loss 4.4693 	B_Loss 3.9897 	V_Loss 0.0480 	
INFO:root:Epoch: [1][99/3551]	Step 100	lr 2e-05	mm 0.996	Loss 4.4754 	B_Loss 3.9847 	V_Loss 0.0491 	
INFO:root:Epoch: [1][149/3551]	Step 150	lr 3e-05	mm 0.996	Loss 4.4607 	B_Loss 3.9877 	V_Loss 0.0473 	
INFO:root:Epoch: [1][199/3551]	Step 200	lr 3e-05	mm 0.996	Loss 4.4184 	B_Loss 3.9629 	V_Loss 0.0455 	
INFO:root:Epoch: [1][249/3551]	Step 250	lr 4e-05	mm 0.996	Loss 4.3931 	B_Loss 3.9424 	V_Loss 0.0451 	
INFO:root:Epoch: [1][299/3551]	Step 300	lr 5e-05	mm 0.996	Loss 4.3529 	B_Loss 3.8961 	V_Loss 0.0457 	
INFO:root:Epoch: [1][349/3551]	Step 350	lr 5e-05	mm 0.996	Loss 4.3279 	B_Loss 3.8616 	V_Loss 0.0466 	
INFO:root:Epoch: [1][399/3551]	Step 400	lr 6e-05	mm 0.996	Loss 4.2749 	B_Loss 3.8076 	V_Loss 0.0467 	
INFO:root:Epoch: [1][449/3551]	Step 450	lr 7e-05	mm 0.996	Loss 4.2013 	B_Loss 3.7675 	V_Loss 0.0434 	
INFO:root:Epoch: [1][499

AttributeError: 'Trainer' object has no attribute 'model'

In [None]:
# total_start_time = time.time()
# roc_auc_scores = []

# for epoch in range(num_epochs):
#     combined_model.train()
#     epoch_loss = 0
#     for batch in tqdm(train_loader):
#         input_ids = batch['caption_ids'].to(device)
#         attention_mask = batch['attention_mask'].to(device)
#         images = batch['imgs']
#         view_1, view_2 = images
#         view_1 = view_1.to(device)
#         view_2 = view_2.to(device)

#         optimizer.zero_grad()

#         # Ensure you are correctly unpacking the outputs from the combined_model forward pass
#         online_1, target_1, text_features_1 = combined_model(view_1, input_ids, attention_mask)
#         online_2, target_2, text_features_2 = combined_model(view_2, input_ids, attention_mask)

#         # Calculate BYOL losses
#         loss_byol = (negative_cosine_similarity(online_1, target_2) + negative_cosine_similarity(online_2, target_1)) / 2

        # Calculate VICReg variance losses
#         variance_I = vicreg_loss(online_1, online_2)
#         variance_T = vicreg_loss(text_features_1, text_features_2)
#         loss_vicreg = F.mse_loss(variance_I, variance_T)

#         # Combined loss
#         loss = (loss_byol + loss_vicreg) / 2
#         loss.backward()
#         optimizer.step()

#         epoch_loss += loss.item()

#     logging.info(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss/len(train_loader):.4f}")

# total_end_time = time.time()
# total_training_time = total_end_time - total_start_time
# logging.info(f"Total training time: {total_training_time:.2f seconds}")

# # Save the model checkpoint
# torch.save(combined_model.state_dict(), "combined_model.pth")

In [None]:
# # Training loop for the combined model
# total_start_time = time.time()
# roc_auc_scores = []

# for epoch in range(num_epochs):
#     combined_model.train()
#     epoch_loss = 0
#     for batch in tqdm(train_loader):
#         input_ids = batch['caption_ids'].to(device)
#         attention_mask = batch['attention_mask'].to(device)
#         images = batch['imgs']
#         view_1, view_2 = images
#         view_1 = view_1.to(device)
#         view_2 = view_2.to(device)
        
#         labels = batch['labels'].to(device).float().unsqueeze(1)  # Adjust label shape

#         optimizer.zero_grad()

#         # Forward pass for view_1
#         outputs_1, online_1, target_1, text_features_1 = combined_model(view_1, input_ids, attention_mask)
        
#         # Forward pass for view_2 (assuming you need both views for BYOL and VICReg)
#         outputs_2, online_2, target_2, text_features_2 = combined_model(view_2, input_ids, attention_mask)

#         # Calculate classification loss
#         classification_loss = classification_criterion(outputs_1, labels)

#         # Calculate BYOL losses (assuming negative_cosine_similarity is defined)
#         loss_byol = (negative_cosine_similarity(online_1, target_2) + negative_cosine_similarity(online_2, target_1)) / 2

#         # Calculate VICReg variance losses (assuming vicreg_loss is defined)
#         variance_I = vicreg_loss(online_1, online_2)
#         variance_T = vicreg_loss(text_features_1, text_features_2)
#         loss_vicreg = F.mse_loss(variance_I, variance_T)

#         # Combined loss
#         loss = (classification_loss + loss_byol + loss_vicreg) / 3
#         loss.backward()
#         optimizer.step()

#         epoch_loss += loss.item()

#     logging.info(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss/len(train_loader):.4f}")

# # Validation loop
# combined_model.eval()
# val_loss = 0
# val_labels = []
# val_outputs = []
# with torch.no_grad():
#     for batch in tqdm(valid_loader):
#         input_ids = batch['caption_ids']
#         attention_mask = batch['attention_mask']
#         images = batch['imgs']
#         labels = batch['labels']

#         view_1, view_2 = images
#         view_1 = view_1.to(device)
#         view_2 = view_2.to(device)
#         input_ids = input_ids.to(device)
#         attention_mask = attention_mask.to(device)
#         labels = labels.to(device)

#         q_1, t_1, text_features_1 = combined_model(view_1, input_ids, attention_mask)
#         q_2, t_2, text_features_2 = combined_model(view_2, input_ids, attention_mask)

#         # Calculate classification loss
#         outputs_1, _, _ = combined_model(view_1, input_ids, attention_mask)
#         classification_loss_1 = classification_criterion(outputs_1, labels)

#         outputs_2, _, _ = combined_model(view_2, input_ids, attention_mask)
#         classification_loss_2 = classification_criterion(outputs_2, labels)

#         classification_loss = (classification_loss_1 + classification_loss_2) / 2

#         # Calculate BYOL losses
#         loss_byol = (negative_cosine_similarity(q_1, t_2) + negative_cosine_similarity(q_2, t_1)) / 2

#         # Calculate VICReg variance losses
#         variance_I = vicreg_loss(q_1, q_2)
#         variance_T = vicreg_loss(text_features_1, text_features_2)
#         loss_vicreg = F.mse_loss(variance_I, variance_T)

#         # Combined loss
#         loss = (classification_loss + loss_byol + loss_vicreg) / 3

#         val_loss += loss.item()
#         val_labels.append(labels.cpu().numpy())
#         val_outputs.append(outputs_1.cpu().numpy())  # Use outputs from the first view for ROC AUC calculation

# # Calculate ROC AUC score
# val_labels = np.concatenate(val_labels)
# val_outputs = np.concatenate(val_outputs)
# roc_auc = roc_auc_score(val_labels, val_outputs)
# roc_auc_scores.append(roc_auc)

# logging.info(f"Epoch [{epoch+1}/{num_epochs}], Validation Loss: {val_loss/len(valid_loader):.4f}, ROC AUC: {roc_auc:.4f}")
# print(f"Epoch [{epoch+1}/{num_epochs}], Validation Loss: {val_loss/len(valid_loader):.4f}, ROC AUC: {roc_auc:.4f}")

# total_end_time = time.time()
# total_training_time = total_end_time - total_start_time
# logging.info(f"Total training time: {total_training_time:.2f} seconds")

# # Save the model checkpoint
# torch.save(combined_model.state_dict(), "combined_model.pth")

# # Plot the ROC AUC scores over epochs
# plt.figure()
# plt.plot(range(1, num_epochs+1), roc_auc_scores, marker='o')
# plt.xlabel('Epoch')
# plt.ylabel('ROC AUC Score')
# plt.title('ROC AUC Score over Epochs')
# plt.savefig('roc_auc_scores.png')
# plt.show()

In [None]:
# print(batch.keys()) 