In [None]:
import numpy as np
import pandas as pd
import os
import ast

In [None]:
# Root folder of the MIMIC-CXR images/reports as mounted on Kaggle.
# model_path = "/kaggle/input/convirt1/pytorch/default/1"
dataset_path = os.path.join("/kaggle/input/mimic-cxr-dataset", "official_data_iccv_final")

In [None]:
# !pip install gdown
# gdrive_link_1 = "https://drive.google.com/uc?id=1dincSb_q9LujRshYKspn0mloX3Zm0T_y"
# gdrive_link_2 = "https://drive.google.com/uc?id=1reDpwuNDeXn4Ww50qbozuQCMugf1Vo9a"
# !gdown {gdrive_link_1}
# !gdown {gdrive_link_2}

In [None]:
# model_url = "https://drive.google.com/drive/folders/1_55_Z8f1NYMapbyzwTS3eNkkmpkTmmtX?usp=sharing"
# !gdown --folder "{model_url}"

In [None]:
# Sentence-level train/validation splits (one CSV per split).
train_csv_path = "/kaggle/input/newdts2/convirt_train_sentence.csv"
validate_csv_path = "/kaggle/input/newdts2/convirt_val_sentence.csv"
train_df, val_df = map(pd.read_csv, (train_csv_path, validate_csv_path))

In [None]:
# Report the (rows, columns) shape of each split.
for split_label, split_frame in (("size of train data: ", train_df),
                                 ("size of validate data: ", val_df)):
    print(split_label, split_frame.shape)

## Config

In [None]:
import yaml
import os

# 1. Build the configuration dictionary (it is dumped to config.yml below, in insertion order).
config_data = {
    # --- Training Hyperparameters ---
    "batch_size": 64,          # Physical (per-step) batch size; choose according to available GPU VRAM
    "accumulation_steps": 1,   # Gradient-accumulation steps; effective batch size = batch_size * accumulation_steps (64 * 1 = 64 here)
        
    "start_epoch": 1,
    "epochs": 11,              # Total training epochs (contrastive training tends to converge slowly)
    # "progressive_unfreezing_phase": 1, 
    "eval_every_n_epochs": 1,  # Validate every epoch so overfitting is caught early
    "log_every_n_steps": 20,
    "weight_decay": 1e-4,      # 1e-3 can be high for AdamW; values around 1e-4 to 1e-6 are common
    "fp16_precision": True,    # Mixed precision saves VRAM, which allows a larger batch_size
    "truncation": True,        # Read into SimCLR.truncation; note ClrDataset always tokenizes with truncation=True

    # Gradient-clipping threshold (guards against exploding gradients)
    "max_grad_norm": 1.0,
    "patience": 3,  # Early-stopping patience; presumably passed to EarlyStopping (stop after 3 non-improving validation checks) — confirm
    
    # --- Split Learning Rates ---
    # Separate learning rates, presumably for the ResNet and BERT parameter groups of the optimizer — confirm in the training code.
    "learning_rate_resnet": 3e-4,
    "learning_rate_bert": 3e-5,
    "warmup_epochs": 1,
    
    # --- Checkpoint ---
    # IMPORTANT: set this to None when training from scratch.
    # NOTE(review): the original note says the training code falls back to training from scratch (try/except)
    # when this path does not exist; that code is not visible in this section — confirm.
    "fine_tune_from": "/kaggle/working/MIMIC-CXR_unfreeze_1_change_valid/runs/Dec13_07-11-09_34681e2a5120", 

    # --- Model Configuration ---
    # Keys match ModelCLR's constructor arguments (presumably passed as ModelCLR(**config["model"])).
    "model": {
        "out_dim": 512,         # Projection head output dimension (shared embedding size)
        "res_base_model": "resnet50",  # ModelCLR accepts "resnet18" or "resnet50"
        "bert_base_model": "emilyalsentzer/Bio_ClinicalBERT",
        "freeze_layers": [0, 1, 2, 3, 4, 5],  # BERT encoder layers frozen at load (ModelCLR then freezes all BERT params anyway)
        "do_lower_case": False  # Accepted but not used by ModelCLR
    },

    # --- Train Configuration ---
    "train": {
        # "freeze_resnet": False,
        "unfreeze_resnet_block": 3,  # Read by SimCLR.__init__; presumably the n for ModelCLR.unfreeze_resnet_last_n_blocks
        "unfreeze_bert_layer": 2,    # Read by SimCLR.__init__; presumably the n for ModelCLR.unfreeze_bert_last_n_layers
        "use_loss": "sigmoid", # "sigmoid" -> SigmoidContrastiveLoss, "ntxent" -> NTXentLoss
        "trainable_t": True, # only used by the sigmoid loss (learned logit scale + bias)
    },
    
    # --- Dataset Configuration ---
    # Keys match DataSetWrapper's constructor arguments.
    "dataset": {
        "s": 1,                 # Stored by DataSetWrapper; not used by its visible transforms
        "input_shape": "(224,224,3)",  # String form, parsed into a tuple by DataSetWrapper
        "num_workers": 4,       # DataLoader worker processes (Kaggle usually provides 2-4 CPU cores)
        "valid_size": 0.1,      # Stored but unused: train/val come from separate CSV files
        
        "train_csv_file": train_csv_path,
        "val_csv_file": validate_csv_path,
        "text_from_files": False,  # False: report text is read directly from the CSV's text_col

        # Make sure these paths are correct in your environment
        "text_root_dir": dataset_path,
        "img_root_dir": dataset_path,
        
        "img_path_col": 1,  # Positional (iloc) column index of the relative image path
        "text_col": 2       # Positional (iloc) column index of the report text
    },

    # --- Loss Configuration ---
    "loss": {
        "temperature": 0.1,         # Logit temperature for NTXentLoss; also used by SigmoidContrastiveLoss when trainable_t is False
        "use_cosine_similarity": True,  # Accepted by NTXentLoss but not used by it
        "alpha_weight": 0.75        # NTXentLoss weight of the first (zis -> zjs) direction
    }
}

# 2. Save the dictionary to a .yml file (relative to the current working directory)
output_path = "config.yml" 

with open(output_path, 'w') as f:
    yaml.dump(config_data, f, default_flow_style=False, sort_keys=False)

print(f"‚úî ƒê√£ t·∫°o file c·∫•u h√¨nh m·ªõi t·∫°i: {os.path.abspath(output_path)}")
print(f"‚úî Effective Batch Size: {config_data['batch_size'] * config_data['accumulation_steps']}")


In [None]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms, models
from transformers import AutoTokenizer, AutoModel
from typing import Tuple, List, Dict, Any


## Dataset Wrapper

In [None]:
from PIL import Image
import pandas as pd
import os
import random

class ClrDataset(Dataset):
    """Contrastive Learning Representations Dataset.

    Each item pairs one image with one randomly chosen sentence of its report.
    Sentences are obtained by removing newlines and splitting the text on ".".

    Args:
        csv_file: CSV listing the samples; columns are addressed by position.
        img_root_dir: directory joined in front of the image-path column.
        input_shape: shape tuple; images are converted to RGB when
            ``input_shape[2] == 3``.
        img_path_col: positional index of the image-path column.
        text_col: positional index of the report-text column (or of the
            text-file path when ``text_from_files`` is True).
        text_from_files: if True, the report is the first line of the file at
            ``text_root_dir / <text_col value>``.
        text_root_dir: directory for the report files.
        transform: optional callable applied to the ``{'image', 'phrase'}`` dict.
        tokenizer: optional HuggingFace-style tokenizer; when given, the phrase is
            tokenized here and padded/truncated to ``max_length``.
        max_length: tokenizer max length.
    """

    def __init__(self,
                 csv_file,
                 img_root_dir,
                 input_shape,
                 img_path_col,
                 text_col,
                 text_from_files,
                 text_root_dir,
                 transform=None,
                 tokenizer=None,
                 max_length=512):

        self.clr_frame = pd.read_csv(csv_file)
        self.img_root_dir = img_root_dir
        self.transform = transform
        self.input_shape = input_shape
        self.img_path_col = int(img_path_col)
        self.text_col = int(text_col)
        self.text_from_files = text_from_files
        self.text_root_dir = text_root_dir
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.clr_frame)

    @staticmethod
    def _pick_phrase(text):
        """Strip newlines, split on '.', and return one random non-blank sentence.

        Falls back to the cleaned text itself when it has no non-blank sentence
        (e.g. an empty string).
        """
        cleaned = str(text).replace("\n", "")
        sentences = [s for s in cleaned.split(".") if s.strip()]
        return random.choice(sentences) if sentences else cleaned

    def _load_image(self, idx):
        """Load the image of row ``idx`` fully into memory and close the file."""
        img_name = os.path.join(self.img_root_dir,
                                self.clr_frame.iloc[idx, self.img_path_col])
        # The context manager releases the file handle; convert()/copy() force a
        # full decode first, so the returned image no longer needs the file.
        with Image.open(img_name) as img:
            if self.input_shape[2] == 3:
                return img.convert('RGB')
            return img.copy()

    def _load_text(self, idx):
        """Return the raw report text of row ``idx`` (not yet split into sentences)."""
        raw = self.clr_frame.iloc[idx, self.text_col]
        if not self.text_from_files:
            # Empty CSV cells are read as NaN; don't train on the literal string "nan".
            return "" if pd.isna(raw) else raw
        text_path = os.path.join(self.text_root_dir, raw)
        with open(text_path) as f:
            # Only the first line is used; readline() gives "" for an empty file
            # instead of raising IndexError.
            return f.readline()

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # 1. Image, 2. one random sentence of the report.
        image = self._load_image(idx)
        phrase = self._pick_phrase(self._load_text(idx))

        # 3. Tokenize here so DataLoader workers do the work.
        if self.tokenizer:
            tokenized_output = self.tokenizer(
                phrase,
                max_length=self.max_length,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )
            # The tokenizer adds a batch dimension (1, seq_len); drop it -> (seq_len,).
            phrase_input = {k: v.squeeze(0) for k, v in tokenized_output.items()}
            sample = {'image': image, 'phrase': phrase_input}
        else:
            sample = {'image': image, 'phrase': phrase}

        # 4. Optional transform; with SimCLRDataTransform this returns (image_tensor, phrase).
        if self.transform:
            sample = self.transform(sample)

        return sample


In [None]:
import cv2
np.random.seed(0)


class GaussianBlur(object):
    """Randomly apply a Gaussian blur (50% chance), as in the SimCLR augmentations.

    Args:
        kernel_size: requested blur kernel size (SimCLR uses ~10% of the image
            side). ``cv2.GaussianBlur`` only accepts positive odd sizes, so an
            even value is rounded up to the next odd integer and values below 1
            become 1.
        min, max: sigma is drawn uniformly from ``[min, max)``.

    Calling the instance always returns a NumPy array, blurred or not.
    Randomness comes from NumPy's global RNG.
    """
    def __init__(self, kernel_size, min=0.1, max=2.0):
        self.min = min
        self.max = max
        # Note: `max`/`min` are shadowed by the parameters, so clamp manually.
        size = int(kernel_size)
        if size < 1:
            size = 1
        if size % 2 == 0:
            # e.g. int(0.1 * 224) == 22 would make cv2 raise; use 23 instead.
            size += 1
        self.kernel_size = size

    def __call__(self, sample):
        sample = np.array(sample)

        # blur the image with a 50% chance
        prob = np.random.random_sample()

        if prob < 0.5:
            sigma = (self.max - self.min) * np.random.random_sample() + self.min
            sample = cv2.GaussianBlur(sample, (self.kernel_size, self.kernel_size), sigma)

        return sample
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import torchvision.transforms as transforms

class DataSetWrapper(object):
    """Build the train/validation DataLoaders for image-text contrastive training.

    Train and validation samples come from two separate CSV files; ``valid_size``
    is stored but not used for splitting. Each batch element is an
    ``(image_tensor, phrase)`` pair produced by ``SimCLRDataTransform``, where
    ``phrase`` is a dict of tokenizer tensors when ``tokenizer`` is given,
    otherwise the raw sentence string.

    Args:
        batch_size: DataLoader batch size (incomplete final batches are dropped).
        num_workers: DataLoader worker processes.
        valid_size: kept for config compatibility; unused.
        input_shape: ``(H, W, C)`` tuple, or its string form such as ``"(224,224,3)"``.
        s: kept for config compatibility; unused by the transforms below.
        train_csv_file, val_csv_file: CSV files given to ``ClrDataset``.
        img_root_dir, img_path_col, text_col, text_from_files, text_root_dir:
            forwarded to ``ClrDataset``.
        tokenizer: optional tokenizer forwarded to ``ClrDataset``.
    """
    def __init__(self,
                batch_size,
                num_workers,
                valid_size,
                input_shape,
                s,
                train_csv_file,
                val_csv_file,
                img_root_dir,
                img_path_col,
                text_col,
                text_from_files,
                text_root_dir,
                tokenizer=None):

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.valid_size = valid_size
        self.s = s
        # The shape usually comes from YAML as a string. ast.literal_eval only
        # parses literals (plain eval would execute arbitrary config text);
        # an already-parsed sequence is accepted as well.
        if isinstance(input_shape, str):
            input_shape = ast.literal_eval(input_shape)
        self.input_shape = tuple(input_shape)
        self.train_csv_file = train_csv_file
        self.val_csv_file = val_csv_file
        self.img_root_dir = img_root_dir
        self.img_path_col = img_path_col
        self.text_col = text_col
        self.text_from_files = text_from_files
        self.text_root_dir = text_root_dir
        self.tokenizer = tokenizer

    def get_data_loaders(self):
        """Return ``(train_loader, valid_loader)``.

        Training images get random SimCLR-style augmentations; validation images
        get a deterministic resize + normalize. This keeps the validation loss
        (used for model selection / early stopping) free of augmentation noise.
        """
        train_dataset = self._build_dataset(
            self.train_csv_file, SimCLRDataTransform(self._get_simclr_pipeline_transform()))
        valid_dataset = self._build_dataset(
            self.val_csv_file, SimCLRDataTransform(self._get_eval_transform()))
        return self.get_train_validation_data_loaders(train_dataset, valid_dataset)

    def _build_dataset(self, csv_file, transform):
        """Create a ClrDataset for ``csv_file`` sharing this wrapper's settings."""
        return ClrDataset(csv_file=csv_file,
                          img_root_dir=self.img_root_dir,
                          input_shape=self.input_shape,
                          img_path_col=self.img_path_col,
                          text_col=self.text_col,
                          text_from_files=self.text_from_files,
                          text_root_dir=self.text_root_dir,
                          transform=transform,
                          tokenizer=self.tokenizer)

    def _get_simclr_pipeline_transform(self):
        """Random training augmentations followed by ImageNet normalization."""
        data_transforms = transforms.Compose([
                                            transforms.Resize((self.input_shape[0], self.input_shape[1])),
                                            transforms.RandomResizedCrop(size=self.input_shape[0], scale=(0.8, 1.0)),
                                            transforms.RandomHorizontalFlip(),
                                            transforms.RandomGrayscale(p=0.2),
                                            transforms.ToTensor(),
                                            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
                                            ])
        return data_transforms

    def _get_eval_transform(self):
        """Deterministic validation transform: resize + ImageNet normalization."""
        return transforms.Compose([
            transforms.Resize((self.input_shape[0], self.input_shape[1])),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])

    def get_train_validation_data_loaders(self, train_dataset, valid_dataset):
        """Wrap the datasets in DataLoaders.

        The training loader reshuffles every epoch, so in-batch negatives differ
        between epochs; the validation loader keeps a fixed order.
        """
        loader_kwargs = dict(batch_size=self.batch_size,
                             num_workers=self.num_workers,
                             drop_last=True,
                             pin_memory=True)
        train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
        valid_loader = DataLoader(valid_dataset, shuffle=False, **loader_kwargs)
        return train_loader, valid_loader


class SimCLRDataTransform(object):
    """Turn a ``{'image', 'phrase'}`` sample into an ``(image, phrase)`` tuple.

    Only the image goes through ``transform_image``; the phrase (raw string or
    tokenized dict) is passed through untouched.
    """
    def __init__(self, transform_image):
        self.transform_image = transform_image

    def __call__(self, sample):
        return self.transform_image(sample['image']), sample['phrase']


## Loss function

In [None]:
import torch.nn.functional as F

class SigmoidContrastiveLoss(nn.Module):
    """Pairwise sigmoid (SigLIP-style) image-text contrastive loss.

    Every (image_i, text_j) pair is scored independently with a log-sigmoid:
    diagonal pairs are positives (label +1), all others negatives (label -1).
    Diagonal terms are weighted by the batch size ``B`` and off-diagonal terms
    by 1, roughly balancing the ``B`` positives against the ``B*(B-1)``
    negatives. The loss is the mean of the weighted matrix.

    Args:
        temperature: fixed logit divisor, used when ``trainable_t`` is False.
        trainable_t: if True, learn ``log_temperature`` (logits are multiplied by
            ``exp(log_temperature)``) and an additive ``bias``, both initialized to 0.

    Running debug metrics ``cumulative_mean_diag``, ``cumulative_top1_acc`` and
    ``cumulative_mean_off`` are summed over calls (never reset here). They are
    computed without autograd, so they do not keep each batch's graph alive.
    """
    def __init__(self, temperature=0.5, trainable_t = True):
        super().__init__()
        self.trainable_t = trainable_t
        if trainable_t:
            self.log_temperature = nn.Parameter(torch.zeros(1))
            self.bias = nn.Parameter(torch.zeros(1))
        else:
            self.temperature = temperature
        self.cumulative_mean_diag = 0
        self.cumulative_top1_acc = 0
        self.cumulative_mean_off = 0

    def forward(self, z_img, z_txt):
        z_img = F.normalize(z_img, dim=1)
        z_txt = F.normalize(z_txt, dim=1)

        logits = z_img @ z_txt.T
        if self.trainable_t:
            logits = logits * torch.exp(self.log_temperature) + self.bias
        else:
            logits = logits / self.temperature

        B = logits.size(0)
        eye = torch.eye(B, device=logits.device)
        labels = eye * 2 - 1            # +1 on the diagonal, -1 elsewhere
        weights = (B - 1) * eye + 1     # B on the diagonal, 1 elsewhere
        sigmoid_loss = (-F.logsigmoid(labels * logits) * weights).mean()

        self._update_metrics(logits)
        return sigmoid_loss

    @torch.no_grad()
    def _update_metrics(self, logits):
        """Accumulate mean positive logit, image->text top-1 accuracy and mean negative logit."""
        B = logits.size(0)
        self.cumulative_mean_diag += logits.diag().mean()
        top1 = (logits.argmax(dim=1) == torch.arange(B, device=logits.device)).float().mean()
        self.cumulative_top1_acc += top1
        off = logits[~torch.eye(B, dtype=torch.bool, device=logits.device)]
        self.cumulative_mean_off += off.mean()


class SigLipLoss(nn.Module):
    """ Sigmoid Loss for Language Image Pre-Training (SigLIP) - https://arxiv.org/abs/2303.15343

    @article{zhai2023sigmoid,
      title={Sigmoid loss for language image pre-training},
      author={Zhai, Xiaohua and Mustafa, Basil and Kolesnikov, Alexander and Beyer, Lucas},
      journal={arXiv preprint arXiv:2303.15343},
      year={2023}
    }

    The loss is ``-sum(logsigmoid(labels * logits)) / B`` with labels +1 on the
    diagonal and -1 elsewhere.

    Running debug metrics (``cumulative_mean_diag``, ``cumulative_top1_acc``,
    ``cumulative_mean_off``) are summed over calls and never reset here. They
    are computed without autograd, and only from the local block that contains
    positives (not from negative-only blocks exchanged with other ranks).

    NOTE(review): the multi-process path (``world_size > 1``) calls
    ``neighbour_exchange_with_grad`` / ``neighbour_exchange_bidir_with_grad``,
    which are not defined in the visible notebook (this class appears to be
    adapted from open_clip); that path raises NameError unless they are provided.
    """
    def __init__(
            self,
            cache_labels=False,
            rank=0,
            world_size=1,
            bidir=True,
            use_horovod=False,
    ):
        super().__init__()
        self.cache_labels = cache_labels
        self.rank = rank
        self.world_size = world_size
        assert not use_horovod  # FIXME need to look at hvd ops for ring transfers
        self.use_horovod = use_horovod
        self.bidir = bidir

        # cache state FIXME cache not currently used, worthwhile?
        self.prev_num_logits = 0
        self.labels = {}

        self.cumulative_mean_diag = 0
        self.cumulative_top1_acc = 0
        self.cumulative_mean_off = 0

    def get_ground_truth(self, device, dtype, num_logits, negative_only=False) -> torch.Tensor:
        """Label matrix: all -1, plus +1 on the diagonal unless ``negative_only``."""
        labels = -torch.ones((num_logits, num_logits), device=device, dtype=dtype)
        if not negative_only:
            labels = 2 * torch.eye(num_logits, device=device, dtype=dtype) + labels
        return labels

    def get_logits(self, image_features, text_features, logit_scale, logit_bias=None):
        """Scaled (and optionally biased) image-text similarity matrix."""
        logits = logit_scale * image_features @ text_features.T
        if logit_bias is not None:
            logits += logit_bias
        return logits

    def _loss(self, image_features, text_features, logit_scale, logit_bias=None, negative_only=False):
        logits = self.get_logits(image_features, text_features, logit_scale, logit_bias)
        labels = self.get_ground_truth(
            image_features.device,
            image_features.dtype,
            image_features.shape[0],
            negative_only=negative_only,
        )

        loss = -F.logsigmoid(labels * logits).sum() / image_features.shape[0]

        # Metrics assume matching pairs sit on the diagonal, which only holds
        # for the local (not negative_only) block.
        if not negative_only:
            self._update_metrics(logits)

        return loss

    @torch.no_grad()
    def _update_metrics(self, logits):
        """Accumulate mean positive logit, image->text top-1 accuracy and mean negative logit."""
        B = logits.shape[0]
        self.cumulative_mean_diag += logits.diag().mean()
        top1 = (logits.argmax(dim=1) == torch.arange(B, device=logits.device)).float().mean()
        self.cumulative_top1_acc += top1
        off = logits[~torch.eye(B, dtype=torch.bool, device=logits.device)]
        self.cumulative_mean_off += off.mean()

    def forward(self, image_features, text_features, logit_scale, logit_bias, output_dict=False):
        image_features = F.normalize(image_features, dim=-1)
        text_features  = F.normalize(text_features, dim=-1)

        loss = self._loss(image_features, text_features, logit_scale, logit_bias)

        if self.world_size > 1:
            # exchange text features w/ neighbour world_size - 1 times
            right_rank = (self.rank + 1) % self.world_size
            left_rank = (self.rank - 1 + self.world_size) % self.world_size
            if self.bidir:
                text_features_to_right = text_features_to_left = text_features
                num_bidir, remainder = divmod(self.world_size - 1, 2)
                for i in range(num_bidir):
                    text_features_recv = neighbour_exchange_bidir_with_grad(
                        left_rank,
                        right_rank,
                        text_features_to_left,
                        text_features_to_right,
                    )

                    for f in text_features_recv:
                        loss += self._loss(
                            image_features,
                            f,
                            logit_scale,
                            logit_bias,
                            negative_only=True,
                        )
                    text_features_to_left, text_features_to_right = text_features_recv

                if remainder:
                    text_features_recv = neighbour_exchange_with_grad(
                        left_rank, right_rank, text_features_to_right)

                    loss += self._loss(
                        image_features,
                        text_features_recv,
                        logit_scale,
                        logit_bias,
                        negative_only=True,
                    )
            else:
                text_features_to_right = text_features
                for i in range(self.world_size - 1):
                    text_features_from_left = neighbour_exchange_with_grad(
                        left_rank, right_rank, text_features_to_right)

                    loss += self._loss(
                        image_features,
                        text_features_from_left,
                        logit_scale,
                        logit_bias,
                        negative_only=True,
                    )
                    text_features_to_right = text_features_from_left

        return {"contrastive_loss": loss} if output_dict else loss


class NTXentLoss(torch.nn.Module):
    """Bidirectional NT-Xent (InfoNCE) loss between paired embeddings.

    Row i of ``zis`` and row i of ``zjs`` form a positive pair; every other row
    of the batch is a negative. Returns
    ``alpha_weight * CE(zis -> zjs) + (1 - alpha_weight) * CE(zjs -> zis)``.

    Args:
        device: stored for compatibility; targets are created on the inputs' device.
        batch_size: stored for compatibility; the batch size is read from the inputs.
        temperature: divisor applied to the similarity logits.
        use_cosine_similarity: accepted for config compatibility; unused
            (normalization is controlled by ``forward(norm=...)``).
        alpha_weight: weight of the zis -> zjs direction.

    Running debug metrics (``cumulative_mean_diag``, ``cumulative_top1_acc``,
    ``cumulative_mean_off``) are computed on the zis -> zjs logits, like the
    sigmoid losses, without autograd, and summed over calls.
    """

    def __init__(self, device, batch_size, temperature, use_cosine_similarity, alpha_weight):
        super(NTXentLoss, self).__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.alpha_weight = alpha_weight
        self.device = device
        self.softmax = torch.nn.Softmax(dim=-1)
        self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")
        self.cumulative_top1_acc = 0
        self.cumulative_mean_diag = 0
        self.cumulative_mean_off = 0

    def softXEnt(self, target, logits):
        """
        Soft-target cross entropy, averaged over rows.
        From the pytorch discussion Forum:
        https://discuss.pytorch.org/t/soft-cross-entropy-loss-tf-has-it-does-pytorch-have-it/69501
        """
        logprobs = torch.nn.functional.log_softmax(logits, dim = 1)
        loss = -(target * logprobs).sum() / logits.shape[0]
        return loss

    def forward(self, zis, zjs,
                    norm=True,
                    weights=1.0):
        """Compute the loss; ``weights`` is accepted for API compatibility and ignored."""
        if norm:
            zis = F.normalize(zis, p=2, dim=1)
            zjs = F.normalize(zjs, p=2, dim=1)

        B = zis.shape[0]
        # One-hot targets: the positive for row i is column i.
        labels = torch.eye(B, device=zis.device)

        logits_ab = zis @ zjs.T / self.temperature
        logits_ba = zjs @ zis.T / self.temperature

        self._update_metrics(logits_ab)

        loss_a = self.softXEnt(labels, logits_ab)
        loss_b = self.softXEnt(labels, logits_ba)

        return self.alpha_weight * loss_a + (1 - self.alpha_weight) * loss_b

    @torch.no_grad()
    def _update_metrics(self, logits):
        """Accumulate mean positive logit, row-wise top-1 accuracy and mean negative logit."""
        B = logits.shape[0]
        self.cumulative_mean_diag += logits.diag().mean()
        top1 = (logits.argmax(dim=1) == torch.arange(B, device=logits.device)).float().mean()
        self.cumulative_top1_acc += top1
        off = logits[~torch.eye(B, dtype=torch.bool, device=logits.device)]
        self.cumulative_mean_off += off.mean()


## Model

In [None]:
"""
Reference for BERT Sentence Embeddings method

@inproceedings{reimers-2019-sentence-bert,
    title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
    author = "Reimers, Nils and Gurevych, Iryna",
    booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing",
    month = "11",
    year = "2019",
    publisher = "Association for Computational Linguistics",
    url = "http://arxiv.org/abs/1908.10084",

"""

import torchvision.models as models
from transformers import AutoModel

# Dual-encoder model (ResNet image branch + BERT text branch) for image-text contrastive learning
class ModelCLR(nn.Module):
    """Dual encoder: a ResNet image branch and a BERT text branch, each followed by
    an MLP projection head into a shared ``out_dim`` space.

    Args:
        res_base_model: "resnet18" or "resnet50" (ImageNet-1k weights).
        bert_base_model: HuggingFace model name/path loaded with ``AutoModel``.
        out_dim: dimension of the shared projection space.
        freeze_layers: BERT encoder layer indices frozen at load time (all BERT
            parameters are frozen afterwards anyway; see unfreeze_bert_last_n_layers).
        do_lower_case: accepted for config compatibility; unused.

    Both backbones start fully frozen. Use ``unfreeze_resnet_last_n_blocks`` /
    ``unfreeze_bert_last_n_layers`` to make parts trainable. BatchNorm2d layers
    of frozen ResNet parts stay in eval mode even after ``model.train()``.
    """
    def __init__(self, res_base_model, bert_base_model, out_dim, freeze_layers, do_lower_case):
        super(ModelCLR, self).__init__()
        # --- Text branch: BERT encoder + MLP projection head ---
        self.bert_model = self._get_bert_basemodel(bert_base_model, freeze_layers)
        # Size the head from the loaded encoder instead of hard-coding 768 (BERT-base),
        # so other BERT sizes work too.
        bert_hidden = self.bert_model.config.hidden_size
        self.bert_proj = nn.Sequential(
            nn.Linear(bert_hidden, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(1024, out_dim)
        )

        # --- Image branch: ResNet backbone (without its fc) + MLP projection head ---
        resnet = self._get_res_basemodel(res_base_model)
        # Only the requested backbone is built; it is still exposed under its name.
        self.resnet_dict = {res_base_model: resnet}
        self.resnet = resnet  # keep original model (layer names are used for unfreezing)

        num_ftrs = resnet.fc.in_features
        # Every child except the final fc layer.
        self.res_features = nn.Sequential(*list(resnet.children())[:-1])
        self.res_proj = nn.Sequential(
            nn.Linear(num_ftrs, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(1024, out_dim)
        )

        # --- Freeze both backbones by default ---
        for param in self.resnet.parameters():
            param.requires_grad = False
        # Frozen BN layers should not update running stats (train() below keeps them in eval).
        for m in self.resnet.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()
        # The projection heads remain trainable.
        for param in self.bert_model.parameters():
            param.requires_grad = False

    def train(self, mode=True):
        """Set train/eval mode, keeping BatchNorm2d layers of frozen ResNet parts in eval.

        ``nn.Module.train`` switches every submodule. Without this override, the
        eval() calls in ``__init__`` would be undone by ``model.train()``, and frozen
        BN layers would keep updating their running statistics.
        """
        super().train(mode)
        resnet = getattr(self, "resnet", None)
        if resnet is not None:
            for m in resnet.modules():
                if isinstance(m, nn.BatchNorm2d) and not any(p.requires_grad for p in m.parameters()):
                    m.eval()
        return self

    def _get_res_basemodel(self, res_model_name):
        """Build the requested torchvision ResNet with ImageNet-1k weights.

        Raises:
            ValueError: if ``res_model_name`` is not "resnet18" or "resnet50".
        """
        factories = {
            "resnet18": lambda: models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1),
            "resnet50": lambda: models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1),
        }
        if res_model_name not in factories:
            raise ValueError("Invalid model name. Check the config file and pass one of: resnet18 or resnet50")
        print("Image feature extractor:", res_model_name)
        return factories[res_model_name]()

    def _get_bert_basemodel(self, bert_model_name, freeze_layers):
        """Load a transformers model and freeze the encoder layers in ``freeze_layers``.

        Raises:
            ValueError: if the model cannot be loaded (original error is chained).
        """
        try:
            model = AutoModel.from_pretrained(bert_model_name)
        except Exception as err:
            raise ValueError(
                "Invalid model name. Check the config file and pass a BERT model from transformers library"
            ) from err
        print("Text feature extractor:", bert_model_name)

        if freeze_layers is not None:
            for layer_idx in freeze_layers:
                for param in model.encoder.layer[layer_idx].parameters():
                    param.requires_grad = False
        return model

    def unfreeze_resnet_last_n_blocks(self, n=1):
        """
        Unfreeze last `n` residual layers (layer4, layer3, ...).
        n=1 => layer4; n=2 => layer4+layer3
        """
        # ResNet stages are named layer1..layer4; walk them from the top.
        layers = [f"layer{i}" for i in range(4, 0, -1)]
        for name in layers[:n]:
            layer = getattr(self.resnet, name)
            for param in layer.parameters():
                param.requires_grad = True
            for m in layer.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.train()

    def unfreeze_bert_last_n_layers(self, n=1):
        """
        Unfreeze the last n encoder layers of BERT (`self.bert_model.encoder.layer`).
        n=1 => last layer only; n larger than the layer count unfreezes all layers.
        """
        encoder = self.bert_model.encoder
        total = len(encoder.layer)
        print("total number of bert layer: ", total)
        n = min(n, total)
        for i in range(total - n, total):
            for param in encoder.layer[i].parameters():
                param.requires_grad = True

    def mean_pooling(self, model_output, attention_mask):
        """
        Mean Pooling - Take attention mask into account for correct averaging
        Reference: https://www.sbert.net/docs/usage/computing_sentence_embeddings.html
        """
        token_embeddings = model_output[0]  # first element holds all token embeddings
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
        # Clamp avoids division by zero for an all-padding row.
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sum_embeddings / sum_mask

    def image_encoder(self, xis):
        """Return (backbone features [B, C], projected embedding [B, out_dim])."""
        h = self.res_features(xis)   # [B, C, 1, 1]
        h = torch.flatten(h, 1)      # [B, C]
        x = self.res_proj(h)
        return h, x

    def text_encoder(self, encoded_inputs):
        """Mean-pool BERT token states (masked) and project to [B, out_dim]."""
        outputs = self.bert_model(**encoded_inputs)
        sentence_embeddings = self.mean_pooling(outputs, encoded_inputs['attention_mask'])
        out_emb = self.bert_proj(sentence_embeddings)
        return out_emb

    def forward(self, xis, encoded_inputs):
        """Return (image embeddings, text embeddings), both un-normalized."""
        h, zis = self.image_encoder(xis)
        zls = self.text_encoder(encoded_inputs)
        return zis, zls


In [None]:
import logging
import shutil
import os
import yaml
import numpy as np
import torch
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from time import time

# Import Transformers
from transformers import AutoTokenizer, get_cosine_schedule_with_warmup

# --- LOGGING CONFIGURATION ---
logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR)
os.environ["TOKENIZERS_PARALLELISM"] = "false"
torch.manual_seed(0)

# --- HELPER FUNCTION ---
def _save_config_file(model_checkpoints_folder):
    if not os.path.exists(model_checkpoints_folder):
        os.makedirs(model_checkpoints_folder)
        if os.path.exists("/kaggle/working/config.yml"):
            shutil.copy("/kaggle/working/config.yml", os.path.join(model_checkpoints_folder, "config.yml"))

def count_trainable_params(model):
    """Print and return the number of parameters with ``requires_grad=True``.

    Args:
        model: any ``nn.Module``.

    Returns:
        int: total element count of the trainable parameters (returned so
        callers can log or assert on it; printing is unchanged).
    """
    total = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("Trainable params:", total)
    return total


# ==========================================
# 1. EARLY STOPPING
# ==========================================
class EarlyStopping:
    """Stop training early when the validation loss stops improving.

    Call the instance with each new validation loss. A loss counts as an
    improvement when it is at least ``delta`` below the best loss so far (with
    ``delta=0`` an equal loss also counts). On improvement, the model's
    ``state_dict`` is saved to ``path`` and the counter resets. After ``patience``
    consecutive non-improving calls, ``early_stop`` becomes True.

    Args:
        patience: number of consecutive non-improving calls tolerated.
        verbose: if True, report every checkpoint save via ``trace_func``.
        delta: minimum decrease in loss that counts as an improvement.
        path: file the best ``state_dict`` is written to.
        trace_func: callable used for messages (default ``print``).
    """
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # np.inf (lower case): the np.Inf alias was removed in NumPy 2.0.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        # Work with a "higher is better" score.
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Save the model's state_dict after the validation loss improved.'''
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

# ==========================================
# 2. MAIN TRAINING CLASS (SIMCLR)
# ==========================================
class SimCLR(object):
    """Fine-tuning driver for the image/report contrastive model (``ModelCLR``).

    Selects the contrastive loss from ``config["train"]["use_loss"]``, builds a
    per-component AdamW optimizer, and trains with gradient accumulation,
    optional mixed precision, gradient clipping, a warmup + cosine LR schedule,
    TensorBoard logging and early stopping on the validation loss (the best
    weights are written by ``EarlyStopping``).
    """

    # Diagnostic accumulators that the training loop reads from the loss module.
    # NOTE(review): presumably incremented inside the loss forward pass, which
    # validation also calls; confirm against NTXentLoss / SigmoidContrastiveLoss.
    _LOSS_STAT_ATTRS = ("cumulative_top1_acc", "cumulative_mean_diag", "cumulative_mean_off")

    def __init__(self, dataset, config):
        """Set up the device, the TensorBoard writer and the loss function.

        Args:
            dataset: object whose ``get_data_loaders()`` returns
                ``(train_loader, valid_loader)``.
            config: parsed YAML config dict.

        Raises:
            ValueError: if ``config["train"]["use_loss"]`` is neither
                ``"ntxent"`` nor ``"sigmoid"``. Without this check,
                ``self.loss_function`` would stay undefined and fail later.
        """
        self.config = config
        self.device = self._get_device()
        self.writer = SummaryWriter()
        self.dataset = dataset
        use_loss = config["train"]["use_loss"]
        if use_loss == "ntxent":
            self.loss_function = NTXentLoss(
                self.device, config["batch_size"], **config["loss"]
            )
            print("simCLR using ntxent loss")
        elif use_loss == "sigmoid":
            self.loss_function = SigmoidContrastiveLoss(
                config["loss"]["temperature"], config["train"]["trainable_t"]
            ).to(self.device)
            print("simCLR using Sigmoid loss")
        else:
            raise ValueError(
                f"Unsupported config['train']['use_loss']={use_loss!r}; "
                "expected 'ntxent' or 'sigmoid'."
            )
        # Attribute name (typo of "resnet") kept unchanged in case other code reads it.
        self.unfreeze_nesnet_block = config["train"]["unfreeze_resnet_block"]
        self.unfreeze_bert_layer = config["train"]["unfreeze_bert_layer"]
        self.truncation = config["truncation"]

    def _get_device(self):
        """Return ``"cuda"`` when a GPU is available, else ``"cpu"``."""
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print("Running on:", device)
        return device

    @staticmethod
    def build_optimizer_for_finetune(model, config):
        """Build AdamW with separate parameter groups for head, ResNet and BERT.

        Only parameters with ``requires_grad=True`` are included; empty groups
        are omitted. The head uses the ResNet learning rate.

        Args:
            model: the ``ModelCLR`` instance (backbones expected under
                ``resnet``/``bert_model``, or ``visual_encoder``/``text_encoder``).
            config: dict with ``learning_rate_resnet``, ``learning_rate_bert``
                and ``weight_decay``.

        Returns:
            torch.optim.AdamW
        """
        # 1. Collect backbone parameters (ResNet and BERT). Fall back to the
        #    alternative attribute names if the model uses them.
        try:
            resnet_params = list(model.resnet.parameters())
        except AttributeError:
            resnet_params = list(model.visual_encoder.parameters()) if hasattr(model, 'visual_encoder') else []

        try:
            bert_params = list(model.bert_model.parameters())
        except AttributeError:
            bert_params = list(model.text_encoder.parameters()) if hasattr(model, 'text_encoder') else []

        # 2. Head parameters = everything that is not part of a backbone. This works
        #    regardless of what the head modules are called.
        backbone_param_ids = set(map(id, resnet_params)) | set(map(id, bert_params))

        head_params = [p for p in model.parameters() if id(p) not in backbone_param_ids and p.requires_grad]

        # Keep only backbone layers that are not frozen (requires_grad=True).
        resnet_params = [p for p in resnet_params if p.requires_grad]
        bert_params = [p for p in bert_params if p.requires_grad]

        # 3. Learning rates. The head trains at the same rate as the ResNet.
        head_lr = config["learning_rate_resnet"]
        res_lr = config["learning_rate_resnet"]
        bert_lr = config["learning_rate_bert"]
        weight_decay = config["weight_decay"]

        print(f"‚úî Optimizer setup: Head ({len(head_params)} params), ResNet ({len(resnet_params)} params), BERT ({len(bert_params)} params)")

        param_groups = []
        if len(head_params) > 0:
            param_groups.append({'params': head_params, 'lr': head_lr, 'weight_decay': weight_decay})
        if len(resnet_params) > 0:
            param_groups.append({'params': resnet_params, 'lr': res_lr, 'weight_decay': weight_decay})
        if len(bert_params) > 0:
            param_groups.append({'params': bert_params, 'lr': bert_lr, 'weight_decay': weight_decay})

        optimizer = torch.optim.AdamW(param_groups)
        return optimizer

    def _trainable_loss_params(self):
        """Return trainable parameters owned by the loss module.

        With ``trainable_t`` the sigmoid loss presumably holds a learnable
        temperature (TODO confirm). These parameters are not part of ``model``,
        so without this they would never be updated. Returns an empty list when
        the loss is not an ``nn.Module``.
        """
        if not isinstance(self.loss_function, torch.nn.Module):
            return []
        return [p for p in self.loss_function.parameters() if p.requires_grad]

    def _reset_loss_stats(self):
        """Zero the loss module's diagnostic accumulators that exist on it."""
        for attr in self._LOSS_STAT_ATTRS:
            if hasattr(self.loss_function, attr):
                setattr(self.loss_function, attr, 0)

    def _encode_text_batch(self, xls):
        """Move a tokenizer batch to ``self.device`` as the dict ``ModelCLR`` takes."""
        encoded_inputs = {
            'input_ids': xls['input_ids'].to(self.device),
            'attention_mask': xls['attention_mask'].to(self.device),
        }
        if 'token_type_ids' in xls:
            encoded_inputs['token_type_ids'] = xls['token_type_ids'].to(self.device)
        return encoded_inputs

    def _log_epoch_stats(self, epoch_counter, epoch_loss, num_batches, zis, zls):
        """Print and log per-epoch training statistics, then reset accumulators.

        ``zis``/``zls`` are the embeddings of the last training batch; the cosine
        value is therefore a last-batch probe, not an epoch average.
        """
        if num_batches == 0:
            print(f"Epoch {epoch_counter}: no training batches, skipping epoch statistics.")
            return

        retrival_acc = self.loss_function.cumulative_top1_acc / num_batches
        print("top1 retrival acc: ", retrival_acc)
        self.writer.add_scalar("top1_retrival_acc", retrival_acc, global_step=epoch_counter)

        mean_logits_diag = self.loss_function.cumulative_mean_diag / num_batches
        print("mean diag: ", mean_logits_diag)
        self.writer.add_scalar("mean_logits_diag", mean_logits_diag, global_step=epoch_counter)

        mean_logits_off = self.loss_function.cumulative_mean_off / num_batches
        print("mean off: ", mean_logits_off)
        self.writer.add_scalar("mean_logits_off", mean_logits_off, global_step=epoch_counter)

        self._reset_loss_stats()

        # No graph is needed for a logged scalar.
        with torch.no_grad():
            cosine_sim = F.cosine_similarity(zis, zls).mean().item()
        print("cosine:", cosine_sim)
        self.writer.add_scalar("cosine", cosine_sim, global_step=epoch_counter)

        epoch_mean_loss = epoch_loss / num_batches
        print(f"Epoch {epoch_counter} ------ Train Loss: {epoch_mean_loss:.4f}")
        self.writer.add_scalar("epoch_train_loss", epoch_mean_loss, global_step=epoch_counter)

    def train(self):
        """Run fine-tuning from ``start_epoch`` up to (excluding) ``epochs``.

        Validates every ``eval_every_n_epochs`` epochs. ``EarlyStopping`` saves
        the best ``state_dict`` to ``<log_dir>/checkpoints/model.pth`` and may
        end training before the last epoch.
        """
        train_loader, valid_loader = self.dataset.get_data_loaders()

        model = ModelCLR(**self.config["model"]).to(self.device)
        model = self._load_pre_trained_weights(model)

        if self.unfreeze_nesnet_block:
            model.unfreeze_resnet_last_n_blocks(n=self.unfreeze_nesnet_block)
        if self.unfreeze_bert_layer:
            model.unfreeze_bert_last_n_layers(n=self.unfreeze_bert_layer)

        count_trainable_params(model)

        optimizer = self.build_optimizer_for_finetune(model, self.config)
        # Learnable loss parameters must join the optimizer before the scheduler
        # is built, so that the scheduler records their base LR.
        loss_params = self._trainable_loss_params()
        if loss_params:
            optimizer.add_param_group({
                'params': loss_params,
                'lr': self.config["learning_rate_resnet"],
                'weight_decay': 0.0,  # never decay a temperature/bias scalar
            })

        accumulation_steps = max(int(self.config.get("accumulation_steps", 1)), 1)
        max_grad_norm = self.config.get("max_grad_norm", 1.0)

        start_epoch = self.config["start_epoch"]
        end_epoch = self.config["epochs"]
        # The epoch loop below excludes end_epoch, so size the schedule for the
        # epochs this run actually executes. Otherwise the cosine decay never completes.
        num_train_epochs = max(end_epoch - start_epoch, 1)

        num_update_steps_per_epoch = len(train_loader) // accumulation_steps
        if len(train_loader) % accumulation_steps != 0:
            num_update_steps_per_epoch += 1

        total_steps = num_update_steps_per_epoch * num_train_epochs
        warmup_epochs = self.config.get("warmup_epochs", 1)
        warmup_steps = num_update_steps_per_epoch * warmup_epochs

        scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps
        )

        # Honour the fp16_precision flag (previously AMP was always on).
        use_amp = bool(self.config.get("fp16_precision", True))
        scaler = GradScaler(enabled=use_amp)
        model_checkpoints_folder = os.path.join(self.writer.log_dir, "checkpoints")
        _save_config_file(model_checkpoints_folder)

        # --- INITIALISE EARLY STOPPING ---
        patience = self.config.get("patience", 10)
        save_path = os.path.join(model_checkpoints_folder, "model.pth")

        early_stopping = EarlyStopping(
            patience=patience,
            verbose=True,
            path=save_path
        )

        n_iter = 0
        valid_n_iter = 0

        print(f"Training with Differential LR: BERT={self.config.get('learning_rate_bert')}, ResNet={self.config.get('learning_rate_resnet')}")
        print(f"Accumulation steps: {accumulation_steps}")
        print(f"Early Stopping Patience: {patience}")

        for epoch_counter in range(start_epoch, end_epoch):
            epoch_loss = 0.0
            num_batches = 0
            zis = zls = None
            # Validation also runs the loss function, so start from clean
            # accumulators; the epoch stats then reflect training batches only.
            self._reset_loss_stats()
            optimizer.zero_grad()

            for batch_idx, (xis, xls) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch_counter}")):
                xis = xis.to(self.device)
                encoded_inputs = self._encode_text_batch(xls)
                with autocast(enabled=use_amp):
                    zis, zls = model(xis, encoded_inputs)
                    loss = self.loss_function(zis, zls)

                # Average gradients over the accumulation window. The logged
                # loss below stays the unscaled per-batch value.
                scaler.scale(loss / accumulation_steps).backward()
                batch_loss = loss.item()

                is_last_batch = (batch_idx + 1) == len(train_loader)
                should_update = ((batch_idx + 1) % accumulation_steps == 0) or is_last_batch

                if should_update:
                    # Unscale first so clipping sees true gradient magnitudes.
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

                    scaler.step(optimizer)
                    scaler.update()
                    scheduler.step()
                    optimizer.zero_grad()

                    if n_iter % self.config["log_every_n_steps"] == 0:
                        self.writer.add_scalar("train_loss", batch_loss, global_step=n_iter)
                        self.writer.add_scalar("lr_head", optimizer.param_groups[0]['lr'], global_step=n_iter)
                    n_iter += 1

                epoch_loss += batch_loss
                num_batches += 1

            self._log_epoch_stats(epoch_counter, epoch_loss, num_batches, zis, zls)

            # --- VALIDATION & EARLY STOPPING ---
            if epoch_counter % self.config["eval_every_n_epochs"] == 0:
                valid_loss = self._validate(model, valid_loader)
                print(f"Validation {epoch_counter} - Valid Loss: {valid_loss:.4f}")
                self.writer.add_scalar("validation_loss", valid_loss, global_step=valid_n_iter)
                valid_n_iter += 1

                # EarlyStopping saves the model itself whenever the loss improves.
                early_stopping(valid_loss, model)

                if early_stopping.early_stop:
                    print("üöÄ  Early stopping triggered! Training stopped.")
                    break

    def _load_pre_trained_weights(self, model):
        """Load ``./runs/<fine_tune_from>/checkpoints/model.pth`` into ``model`` if present.

        Best-effort by design: any failure (missing config key, missing file,
        incompatible state dict) is reported and training starts from scratch.
        """
        try:
            checkpoints_folder = os.path.join("./runs", self.config["fine_tune_from"], "checkpoints")
            model_path = os.path.join(checkpoints_folder, "model.pth")
            if os.path.exists(model_path):
                state_dict = torch.load(model_path, map_location=self.device)
                model.load_state_dict(state_dict)
                print("Loaded pre-trained model with success.")
            else:
                print("Pre-trained weights file not found. Training from scratch.")
        except Exception as e:
            print(f"Exception loading weights: {e}. Training from scratch.")
        return model

    def _validate(self, model, valid_loader):
        """Return the mean validation loss over ``valid_loader`` (0.0 if empty).

        Runs without gradients in eval mode. The model is put back in train
        mode even if validation raises.
        """
        model.eval()
        valid_loss = 0.0
        counter = 0
        zis = zls = None
        try:
            with torch.no_grad():
                for xis, xls in tqdm(valid_loader, desc="Validating"):
                    xis = xis.to(self.device)
                    encoded_inputs = self._encode_text_batch(xls)

                    zis, zls = model(xis, encoded_inputs)
                    loss = self.loss_function(zis, zls)

                    valid_loss += loss.item()
                    counter += 1

                if counter > 0:
                    valid_loss /= counter

                # Last-batch alignment probe; skipped when the loader was empty.
                if zis is not None:
                    cosine_sim = F.cosine_similarity(zis, zls).mean().item()
                    print("cosine:", cosine_sim)
        finally:
            model.train()
        return valid_loss


In [None]:
def main():
    """Entry point: load the YAML config, build the data pipeline and train.

    Reads ``/kaggle/working/config.yml``, which is the same path that
    ``_save_config_file`` copies into the checkpoint folder.
    """
    # Use a context manager so the config file handle is closed (it was leaked before).
    with open("/kaggle/working/config.yml", "r") as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)

    # Build the tokenizer once, here.
    tokenizer = AutoTokenizer.from_pretrained(config["model"]["bert_base_model"])

    # Pass the tokenizer into DataSetWrapper.
    dataset = DataSetWrapper(config['batch_size'], **config['dataset'], tokenizer=tokenizer)

    simclr = SimCLR(dataset, config)
    simclr.train()

if __name__ == "__main__":
    main()
