## Import SimSiam

In [74]:
from tqdm import tqdm
import torch
import torchvision
from torch import nn
from lightly.loss import NegativeCosineSimilarity
from lightly.models.modules import SimSiamPredictionHead, SimSiamProjectionHead
from lightly.transforms import SimSiamTransform
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
import wandb

## Parameters

In [75]:
# Training hyperparameters shared by the optimizer, the DataLoader and the epoch loop.
parameters = dict(
    learning_rate=0.06,
    epochs=200,
    batch_size=32,
)

In [None]:
# Start a Weights & Biases run. The hyperparameters in the config are taken from
# `parameters`, so what is logged is exactly what the training code uses.
wandb.init(
    # W&B project the run is logged under (name kept unchanged so existing runs stay grouped).
    project="simsaim_tiny_imagenet",
    # Hyperparameters plus run metadata.
    config={
        **parameters,
        "architecture": "SimSiam",
        "dataset": "tiny_imagenet",
    },
)

## 建立 Model

In [76]:
class SimSiam(nn.Module):
    """SimSiam network: backbone -> projection head -> prediction head.

    Args:
        backbone: feature extractor. Its output is flattened from dim 1, and the
            flattened size must equal ``num_ftrs``. The default of 512 matches
            torchvision's ResNet-18 with its final ``fc`` layer removed.
        num_ftrs: size of the flattened backbone features.
        proj_hidden_dim: hidden size of the projection head.
        pred_hidden_dim: hidden (bottleneck) size of the prediction head.
        out_dim: embedding size produced by both the projection and prediction heads.

    The defaults reproduce the original hard-coded configuration
    ``SimSiamProjectionHead(512, 512, 128)`` / ``SimSiamPredictionHead(128, 64, 128)``.
    """

    def __init__(self, backbone, num_ftrs=512, proj_hidden_dim=512,
                 pred_hidden_dim=64, out_dim=128):
        super().__init__()
        self.backbone = backbone
        self.projection_head = SimSiamProjectionHead(num_ftrs, proj_hidden_dim, out_dim)
        self.prediction_head = SimSiamPredictionHead(out_dim, pred_hidden_dim, out_dim)

    def forward(self, x):
        """Return ``(z, p)``: the detached projection and the prediction for batch ``x``."""
        f = self.backbone(x).flatten(start_dim=1)
        z = self.projection_head(f)
        p = self.prediction_head(z)
        # Stop-gradient on z: gradients reach the network only through p.
        z = z.detach()
        return z, p

# Randomly initialised ResNet-18 (no pretrained weights requested). Its last child,
# the classification `fc` layer, is dropped so the backbone yields pooled features.
# NOTE(review): with 64x64 inputs the ImageNet-style stem (stride-2 conv + max-pool)
# leaves a very small spatial map before pooling; a CIFAR-style stem is a common
# alternative for small images — consider it if accuracy matters.
resnet = torchvision.models.resnet18()
backbone = nn.Sequential(*tuple(resnet.children())[:-1])
model = SimSiam(backbone)

In [77]:
# Pick the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)

SimSiam(
  (backbone): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True

## 資料處理

### 照片增強處理與載入資料集

In [84]:
import torchvision.transforms as transforms
from lightly.transforms import SimSiamTransform


def _to_rgb(img):
    """Return ``img`` as a 3-channel RGB image (assumes a PIL image, as yielded by `datasets`)."""
    return img.convert("RGB")


# Tiny ImageNet images are 64x64. Some images may be single-channel, so every image
# is converted to 3-channel RGB before augmentation. Unlike
# transforms.Grayscale(num_output_channels=3), which turned *every* image into
# grayscale, this keeps the colour information of RGB images. Random grayscale
# remains available as an augmentation inside SimSiamTransform (see lightly docs).
# A named function (not a lambda) keeps the pipeline picklable for DataLoader workers.
transform = transforms.Compose([
    transforms.Lambda(_to_rgb),
    SimSiamTransform(input_size=64),
])

ds = load_dataset("zh-plus/tiny-imagenet")

### 自行定義 Dataset

In [85]:
# Custom Dataset that applies `transform` to every image of the wrapped split.
class TinyImageNetDataset(Dataset):
    """Wrap an indexable split whose items are dicts with an ``"image"`` entry.

    Each item is ``(transform(image), transform(image))``; the transform is called
    twice, in that order, on the same source image.
    """

    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # NOTE(review): the transform appears to already return two augmented views
        # per call (the training loop unpacks two tensors from batch[0] alone), so the
        # second call looks like unused augmentation work. Removing it changes the
        # item structure the training loop relies on — confirm before changing.
        image = self.dataset[idx]["image"]
        return tuple(self.transform(image) for _ in range(2))

In [86]:
# Wrap the Hugging Face training split so each item is passed through `transform`.
train_ds = TinyImageNetDataset(dataset=ds["train"], transform=transform)

### 設定 DataLoader

In [87]:
# DataLoader over the training set, reshuffled every epoch.
# - num_workers=0: loading/augmentation runs in the main process. Raising it speeds up
#   data loading, but on platforms that spawn workers the dataset must be picklable.
# - pin_memory: only useful when batches are copied to a CUDA device, so it is enabled
#   only in that case (on CPU-only machines it just triggers a warning).
# - drop_last=True: discards an undersized final batch; a batch of size 1 would make
#   BatchNorm raise in training mode.
dataloader = DataLoader(
    train_ds,
    batch_size=parameters["batch_size"],
    shuffle=True,
    num_workers=0,
    pin_memory=(device == "cuda"),
    drop_last=True,
)

In [88]:
# Loss (negative cosine similarity, applied symmetrically in the training loop) and a
# plain SGD optimizer (no momentum / weight decay) over all model parameters.
criterion = NegativeCosineSimilarity()
optimizer = torch.optim.SGD(params=model.parameters(), lr=parameters["learning_rate"])

In [83]:
# Start training.
print("Starting Training")
model.train()  # ensure BatchNorm layers use batch statistics, even if eval() was called earlier
try:
    for epoch in range(parameters["epochs"]):
        total_loss = 0.0
        # Wrap the dataloader in tqdm to show per-batch progress.
        with tqdm(dataloader, unit="batch") as tepoch:
            tepoch.set_description(f"Epoch {epoch+1}")
            for batch in tepoch:
                # Each dataset item is a pair of transform outputs; only the first is used,
                # and it unpacks into the two augmented views of each image.
                x0, x1 = batch[0]
                x0 = x0.to(device, non_blocking=True)
                x1 = x1.to(device, non_blocking=True)

                # The model returns (detached projection z, prediction p) per view.
                z0, p0 = model(x0)
                z1, p1 = model(x1)

                # Symmetric loss: each view's z is compared with the other view's p.
                loss = 0.5 * (criterion(z0, p1) + criterion(z1, p0))
                # Zero first so stale gradients from an interrupted notebook run are not applied.
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # One device->host sync per step, reused for the epoch sum and the progress bar.
                loss_value = loss.item()
                total_loss += loss_value
                tepoch.set_postfix(loss=loss_value)

        # max(..., 1) guards against an empty dataloader (ZeroDivisionError).
        avg_loss = total_loss / max(len(dataloader), 1)
        # Log a plain float (not a tensor) to W&B.
        wandb.log({"Loss": avg_loss})
        print(f"Epoch: {epoch+1:>02}, Average Loss: {avg_loss:.5f}")
finally:
    # Close the W&B run even if training is interrupted or raises.
    wandb.finish()

Starting Training


Epoch 1:  58%|█████▊    | 1802/3125 [06:59<04:33,  4.84batch/s, loss=-0.74] 

## 儲存模型

In [None]:
from pathlib import Path

# Save only the model weights (state_dict). torch.save does not create missing
# directories, so the parent folder is created first.
save_path = Path("./pretrainModel/simsaim/pretrained_simsiam.pth")
save_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), save_path)