In [32]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import TQDMProgressBar, Callback
from torchvision import transforms
from torchvision.models import resnet18
from time import time

In [33]:
import torch
# Report GPU availability. Device details are only queried when CUDA is present,
# because torch.cuda.current_device() raises on CPU-only machines.
cuda_available = torch.cuda.is_available()
print(f"CUDA available: {cuda_available}")
if cuda_available:
    current = torch.cuda.current_device()
    print(f"Current device: {current}")
    print(f"Device name: {torch.cuda.get_device_name(current)}")

CUDA available: True
Current device: 0
Device name: NVIDIA GeForce RTX 3050 Ti Laptop GPU


In [34]:
# 1) Dataset with sanity-check
# -----------------------------------
class SimCLRDataset(Dataset):
    """Unlabeled image dataset yielding two independently augmented views per image.

    Image paths are read from the ``file`` column of a CSV and converted to
    absolute paths (resolved against the current working directory).

    Args:
        csv_file: Path to a CSV that has a ``file`` column of image paths.
        transform: Callable applied twice to the same RGB image; each call
            produces one view of the positive pair.
    """

    def __init__(self, csv_file, transform):
        self.data = pd.read_csv(csv_file)
        assert not self.data.empty, f"CSV {csv_file} is empty!"

        # Convert every path in the 'file' column to an absolute path
        self.data['file'] = self.data['file'].apply(os.path.abspath)

        # Cheap sanity check: only the first 5 paths are verified to exist
        for p in self.data['file'].iloc[:5]:
            assert os.path.isfile(p), f"File not found: {p}"

        # Plain list of paths: indexing it avoids building a pandas row Series per sample
        self.files = self.data['file'].tolist()

        print(f"✅ Loaded {len(self.data)} samples from {csv_file}")
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(xi, xj)``: two random augmentations of image ``idx``."""
        img_path = self.files[idx]
        # The context manager closes the underlying file handle; convert()
        # produces a fully loaded copy first, so `image` stays usable afterwards.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        xi = self.transform(image)
        xj = self.transform(image)
        return xi, xj

In [35]:
# 2) SimCLR LightningModule
# -----------------------------------
class SimCLR(pl.LightningModule):
    """SimCLR model: ResNet-18 encoder + 2-layer MLP projection head, trained with NT-Xent.

    Args:
        temperature: Divisor applied to cosine similarities before the softmax
            in ``info_nce_loss``.
        lr: Learning rate for the Adam optimizer.
    """

    def __init__(self, temperature=0.5, lr=1e-3):
        super().__init__()
        # NOTE: process-wide setting, not scoped to this module. Per the PyTorch
        # docs, 'medium' lets float32 matmuls use reduced-precision internal
        # arithmetic for speed.
        torch.set_float32_matmul_precision('medium')
        # Record temperature/lr in self.hparams and in Lightning checkpoints so
        # SimCLR.load_from_checkpoint can rebuild the module with the same settings.
        self.save_hyperparameters()
        self.temperature = temperature
        self.lr = lr

        # Encoder: ResNet-18 without pretrained weights; its classifier is replaced
        # by Identity so it outputs the num_ftrs-dim feature vector.
        backbone = resnet18(weights=None)
        num_ftrs = backbone.fc.in_features
        backbone.fc = nn.Identity()
        self.encoder = backbone

        # Projection head: num_ftrs -> 256 -> 128
        self.projector = nn.Sequential(
            nn.Linear(num_ftrs, 256),
            nn.ReLU(),
            nn.Linear(256, 128)
        )

    def forward(self, x):
        """Return L2-normalised 128-d projections of the images in ``x``."""
        h = self.encoder(x)
        z = self.projector(h)
        return F.normalize(z, dim=1)

    def info_nce_loss(self, z):
        """NT-Xent (normalized temperature-scaled cross-entropy) loss.

        Args:
            z: Embeddings of shape (2N, D) laid out as [view_1; view_2], i.e.
               row ``i`` and row ``i + N`` are the two views of the same image.

        Returns:
            Scalar tensor: the mean over all 2N anchors of
            ``-log(exp(sim[i, pos(i)]) / sum_{j != i} exp(sim[i, j]))``.

        Raises:
            ValueError: if ``z`` has zero rows or an odd number of rows.
        """
        n_rows = z.shape[0]
        if n_rows == 0 or n_rows % 2 != 0:
            raise ValueError(
                "info_nce_loss expects a non-empty, even number of rows "
                f"(two views per image), got {n_rows}"
            )
        batch_size = n_rows // 2

        # Re-normalise so the loss is also correct for un-normalised inputs
        # (a no-op for outputs of forward(), which are already unit-norm).
        z = F.normalize(z, dim=1)
        sim = torch.matmul(z, z.T) / self.temperature

        # Remove self-similarity from the softmax denominator. finfo(dtype).min is
        # representable in every float dtype, whereas -9e15 overflows float16
        # (e.g. when running under mixed precision).
        self_mask = torch.eye(n_rows, dtype=torch.bool, device=z.device)
        sim = sim.masked_fill(self_mask, torch.finfo(sim.dtype).min)

        # Column of each row's positive: i -> i + N in the first half, i -> i - N in the second.
        targets = (torch.arange(n_rows, device=z.device) + batch_size) % n_rows

        # Cross-entropy over each similarity row equals -log softmax(sim)[target];
        # it is computed via log-sum-exp (numerically stable) and fully vectorised.
        return F.cross_entropy(sim, targets)

    def training_step(self, batch, batch_idx):
        """Compute and log the NT-Xent loss for one batch of positive pairs ``(xi, xj)``."""
        xi, xj = batch
        # Stack the views as [xi; xj] so row i and row i + N form a positive pair
        x = torch.cat([xi, xj], dim=0)
        z = self(x)
        loss = self.info_nce_loss(z)
        # Sanity print on the very first training step only
        if batch_idx == 0 and self.current_epoch == 0:
            print(f">>> [Debug] First loss = {loss.item():.4f}")
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Adam over all parameters (encoder + projector) with learning rate ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

In [36]:
# 3) Callback that prints a line when each training epoch starts
# -----------------------------------
class EpochPrintCallback(Callback):
    """Print a banner at the start of every training epoch.

    The banner shows the 1-based epoch number out of ``trainer.max_epochs`` and
    the current wall-clock time as seconds since the Unix epoch (``time.time()``).
    """

    def on_train_epoch_start(self, trainer, pl_module):
        # Import the module locally: the notebook's import cell does
        # `from time import time`, which binds the global name `time` to the
        # *function*, so a global `time.time()` only worked if some later cell
        # happened to rebind `time` to the module.
        import time as time_module
        print(f"\n>>> Starting Epoch {trainer.current_epoch+1}/{trainer.max_epochs} at {time_module.time():.1f}s")


In [37]:
# 4) Augmentation pipeline & DataLoader
# -----------------------------------
# SimCLR-style augmentations; every call draws a fresh random view.
simclr_transform = transforms.Compose(
    [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        # Color distortion, applied with probability 0.8
        transforms.RandomApply(
            [transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)],
            p=0.8,
        ),
        transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
    ]
)

TRAIN_CSV = '../data/isic2018/labels/train_unlabeled.csv'
dataset = SimCLRDataset(TRAIN_CSV, simclr_transform)

# NOTE(review): num_workers=0 loads and augments images in the notebook process,
# which limits throughput. Worker processes may fail to unpickle classes defined
# inside a notebook on Windows; confirm before raising it.
loader = DataLoader(
    dataset,
    batch_size=128,
    shuffle=True,
    drop_last=True,  # every batch holds exactly 128 images
    num_workers=0,
    pin_memory=True,
    persistent_workers=False,
)


✅ Loaded 10015 samples from ../data/isic2018/labels/train_unlabeled.csv


In [38]:
import torch
# Sanity check: build the model and confirm its weights end up on the GPU.
print("CUDA available:", torch.cuda.is_available())
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

model = SimCLR()
model = model.to(device)
# Device of the first parameter; expected to read cuda:0 when a GPU is present
p = next(iter(model.parameters()))
print("Parameter device:", p.device)


CUDA available: True
Using device: cuda
Parameter device: cuda:0


In [39]:
import time, torch
# Time one forward + NT-Xent loss + backward pass on a single real batch.
# 1) Check the GPU
print("CUDA:", torch.cuda.is_available())
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 2) Take one real batch from `loader` (built in the DataLoader cell)
batch = next(iter(loader))
xi, xj = batch
xi = xi.to(device)
xj = xj.to(device)

# 3) Forward + loss + backward
model = SimCLR().to(device)

# CUDA kernels launch asynchronously; synchronize before reading the clock so
# the interval covers the GPU work itself. NOTE: this first step likely also
# includes one-time CUDA/cuDNN initialisation, so it overstates steady-state cost.
if device.type == "cuda":
    torch.cuda.synchronize(device)
start = time.perf_counter()
z = model(torch.cat([xi, xj], dim=0))
loss = model.info_nce_loss(z)
loss.backward()
if device.type == "cuda":
    torch.cuda.synchronize(device)
end = time.perf_counter()

print(f"Step time: {end-start:.3f}s, loss={loss.item():.4f}")
print("Model device:", next(model.parameters()).device)
print("Data device:", xi.device)


CUDA: True
Step time: 12.569s, loss=5.5300
Model device: cuda:0
Data device: cuda:0


In [40]:
# 5) Trainer & Fit
# -----------------------------------
SEED = 42
# Seed Python's `random`, NumPy and torch (CPU + CUDA) so weight init, batch
# shuffling and the augmentations (drawn in the main process with num_workers=0)
# repeat across runs. GPU kernels may still be nondeterministic.
pl.seed_everything(SEED, workers=True)

trainer = pl.Trainer(
    max_epochs=100,
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    callbacks=[
        TQDMProgressBar(refresh_rate=1),
        EpochPrintCallback()
    ],
    enable_progress_bar=True,
    enable_model_summary=True,
)

model = SimCLR()
trainer.fit(model, loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type       | Params | Mode 
-------------------------------------------------
0 | encoder   | ResNet     | 11.2 M | train
1 | projector | Sequential | 164 K  | train
-------------------------------------------------
11.3 M    Trainable params
0         Non-trainable params
11.3 M    Total params
45.363    Total estimated model params size (MB)
72        Modules in train mode
0         Modules in eval mode
c:\Users\user\miniconda3\envs\MLF-CoDA-Project\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]


>>> Starting Epoch 1/100 at 1748329725.7s
>>> [Debug] First loss = 5.5302

>>> Starting Epoch 2/100 at 1748331010.3s

>>> Starting Epoch 3/100 at 1748332301.0s

>>> Starting Epoch 4/100 at 1748333707.9s

>>> Starting Epoch 5/100 at 1748335058.2s

>>> Starting Epoch 6/100 at 1748336358.2s

>>> Starting Epoch 7/100 at 1748337658.3s

>>> Starting Epoch 8/100 at 1748338948.0s

>>> Starting Epoch 9/100 at 1748340285.7s

>>> Starting Epoch 10/100 at 1748341629.4s

>>> Starting Epoch 11/100 at 1748342978.4s

>>> Starting Epoch 12/100 at 1748344405.6s

>>> Starting Epoch 13/100 at 1748345758.5s

>>> Starting Epoch 14/100 at 1748347064.1s

>>> Starting Epoch 15/100 at 1748348373.2s

>>> Starting Epoch 16/100 at 1748349644.4s

>>> Starting Epoch 17/100 at 1748350995.9s

>>> Starting Epoch 18/100 at 1748352278.2s

>>> Starting Epoch 19/100 at 1748353587.9s

>>> Starting Epoch 20/100 at 1748354921.7s

>>> Starting Epoch 21/100 at 1748356247.1s

>>> Starting Epoch 22/100 at 1748357678.4s

>>> Star

`Trainer.fit` stopped: `max_epochs=100` reached.


In [41]:
# Save the pretrained encoder (ResNet-18 with fc replaced by Identity);
# create the checkpoint directory first so torch.save does not fail.
os.makedirs("../checkpoints", exist_ok=True)
torch.save(model.encoder.state_dict(), "../checkpoints/simclr_encoder.pth")

In [42]:
# Save the projection-head weights next to the encoder checkpoint.
# The path previously read "../checkoints" (typo), a different directory.
os.makedirs("../checkpoints", exist_ok=True)
torch.save(model.projector.state_dict(), "../checkpoints/simclr_projector.pth")