## simclr

In [1]:
import os
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.models import ResNet18_Weights
import pytorch_lightning as pl
from lightning.pytorch.loggers import WandbLogger
import lightly 
import matplotlib.pyplot as plt
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import normalize
from PIL import Image
import numpy as np
from lightly import data
import glob
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import BatchSampler

In [3]:
os.cpu_count()

80

In [2]:
num_workers = int(os.cpu_count()//2)
seed = 1
max_epochs = 20
#input_size = 128
#num_ftrs = 32

In [3]:
pl.seed_everything(seed)

Seed set to 1


1

In [27]:
path_to_data_train = '/home/abababam1/HandwrittenTextAlign/PRMU/simclr/data/train'
path_to_data_test = '/home/abababam1/HandwrittenTextAlign/PRMU/simclr/data/test'

In [28]:
transform_train = transforms.Compose([
    transforms.Resize((64, 63),antialias=True),  # 画像のサイズ変更
    transforms.Grayscale(num_output_channels=1), #single-channel
    transforms.RandomAffine(degrees=(-20, 20), scale=(0.8, 1.2), fill = 255),
    transforms.ToTensor(),           # テンソルに変換
    transforms.Normalize((0.5,), (0.5,)) #single-channel normalization
])
transform_test = transforms.Compose([
    transforms.Resize((64, 63),antialias=True),  # 画像のサイズ変更
    transforms.Grayscale(num_output_channels=1), #single-channel
    transforms.ToTensor(),           # テンソルに変換
    transforms.Normalize((0.5,), (0.5,)) #single-channel normalization
])

transform_simclr = transforms.Compose([
    transforms.Resize((64, 63),antialias=True),  # 画像のサイズ変更
    transforms.Grayscale(num_output_channels=1), #single-channel
    transforms.RandomAffine(degrees=(-20, 20), scale=(0.8, 1.2), fill = 255),
    #transforms.ToTensor(),           # テンソルに変換
    transforms.Normalize((0.5,), (0.5,)) #single-channel normalization
])

#### データローダー

・バッチサイズ（字ごとに異なる）
・画足りない情報: ラベルに入れる
・バッチサイズの指定を変える

In [29]:
def label_data_dict(path_to_data):
    d = dict() # 画像に対しラベル
    class_indices = dict() # ラベルに対し画像が何個あるか
    for idx, path in enumerate(glob.glob(f'{path_to_data}//*/*.png')):
        char = path.split('/')[-2]
        d[path] = char
        
        if char not in class_indices:
            class_indices[char] = [idx]
        else:
            class_indices[char] += [idx]
    return d, class_indices

class CustomDataset(Dataset):
    def __init__(self, path_to_data, transform=None):
        self.image_paths = []
        self.labels = []
        #self.classes = classes
        
        data, _ = label_data_dict(path_to_data)
        self.image_paths.extend(list(data.keys()))
        self.labels.extend(list(data.values()))

        self.classes = sorted(set(self.labels))

        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        image_path = self.image_paths[index]
        label = self.labels[index]

        label_index = self.classes.index(label)

        image = Image.open(image_path).convert('RGB')

        if self.transform:
            image = self.transform(image)


        return image, label_index

# class_indices: {'label':[data], ...}
class ClassBatchSampler:
    def __init__(self, class_indices):
        self.class_indices = class_indices
        self.classes = list(class_indices.keys())
        #self.current_class = 0

    def __iter__(self):
        # 各クラスのインデックスを順に返す
        for class_label in self.classes:
            indices = self.class_indices[class_label]
            #print(f"Sampling indices for class {class_label}: {indices}", flush=True)  # デバッグ出力
            yield indices

    def __len__(self):
        return len(self.classes)
        

#----------------------------------------------------------------------

# データセットを作成
#dataset = CustomDataset(path_to_data, transform=transform)

#_, class_indices = label_data_dict(path_to_data)

# サンプラーを使ってデータローダーを作成
#sampler = ClassBatchSampler(class_indices)
#dataloader = DataLoader(dataset, batch_sampler=sampler)

# データローダーでクラスごとにデータを取得
#for batch_idx, (data, labels) in enumerate(dataloader):
#    print(f"Batch {batch_idx}:")
#    print(f"Data: {data}")
#    print(f"Labels: {labels}")
#    print(f"Batch size: {len(data)}")

In [41]:
# データセットを作成
dataset_train = CustomDataset(path_to_data_train, transform=transform_train)
dataset_test = CustomDataset(path_to_data_test, transform=transform_test)

_, class_indices_train = label_data_dict(path_to_data_train)
_, class_indices_test = label_data_dict(path_to_data_test)

# サンプラーを使って訓練データローダーを作成
sampler = ClassBatchSampler(class_indices_train)
dataloader_train = DataLoader(dataset_train, batch_sampler=sampler, num_workers=num_workers)

# サンプラーを使ってテストデータローダーを作成
sampler = ClassBatchSampler(class_indices_test)
dataloader_test = DataLoader(dataset_test, batch_sampler=sampler)

In [None]:
collate_fn = lightly.data.SimCLRCollateFunction(
    input_size=input_size,
    vf_prob=0,
    rr_prob=0.3,
)

dataset_train_simclr = lightly.data.LightlyDataset(
    input_dir=path_to_data
)

dataset_test = lightly.data.LightlyDataset(
    input_dir=path_to_data,
    transform=test_transforms
)

dataloader_train_simclr = DataLoader(
    dataset_train_simclr,
    batch_size=sampler,
    shuffle=True,
    collate_fn=collate_fn,
    drop_last=True,
    num_workers=num_workers
)

dataloader_test = DataLoader(
    dataset_test,
    batch_size=sampler,
    shuffle=False,
    drop_last=False,
    num_workers=num_workers
)

#### モデル

In [26]:
from lightly.models.modules.heads import SimCLRProjectionHead
from lightly.loss import NTXentLoss


class SimCLRModel(pl.LightningModule):
    def __init__(self, batch_size=10, transform=None):
        super().__init__()
        
        self.batch_size = batch_size 

        # create a ResNet backbone and remove the classification head
        resnet = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1, progress=True)
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])
        
        # 最初の畳み込み層を1チャンネル入力に対応させる
        self.backbone[0] = torch.nn.Conv2d(
            in_channels=1,  # 入力チャンネル数を1に変更
            out_channels=resnet.conv1.out_channels,
            kernel_size=resnet.conv1.kernel_size,
            stride=resnet.conv1.stride,
            padding=resnet.conv1.padding,
            bias=resnet.conv1.bias is not None
        )

        hidden_dim = resnet.fc.in_features
        self.projection_head = SimCLRProjectionHead(
            input_dim=hidden_dim,
            hidden_dim=2048,
            output_dim=128,
            num_layers=2,
            batch_norm=True
        )
        
        self.transform = transform

        self.criterion = NTXentLoss()

    def forward(self, x):
        h = self.backbone(x).flatten(start_dim=1)
        z = self.projection_head(h)
        return z

    def training_step(self, batch, batch_idx):
        #print(f"Batch content: {batch}")  # バッチの内容を出力して確認
        (x0, x1), *_ = batch
        z0 = self.forward(x0)
        z1 = self.forward(x1)
        loss = self.criterion(z0, z1)
        self.log("train_loss_ssl", loss)
        return loss
    
    def training_step(self, batch, batch_idx):
        images, labels = batch  # 画像とラベルを分けて取得

        # SimCLRでは、1つの画像に対して2つの異なるビューを生成
        # データ拡張を使って2つのビューを作成
        x0 = self.transform(images)
        x1 = self.transform(images)

        # 順伝播
        z0 = self.forward(x0)
        z1 = self.forward(x1)
        loss = self.criterion(z0, z1)

        # ログに損失を記録
        self.log("train_loss_ssl", loss)
        return loss

    def configure_optimizers(self):# lr=0.075*(self.batch_size)**(1/2)
        #optim = torch.optim.SGD(
        #    self.parameters(), lr=6e-2, momentum=0.9, weight_decay=5e-4
        #)
        optim = torch.optim.Adam(
            self.parameters(), lr=1e-3, weight_decay=5e-4
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optim, max_epochs
        )
        return [optim], [scheduler]

#### 訓練

In [9]:
import torch
print(torch.cuda.is_available())  # GPUが使用可能かどうかを確認
#print(torch.cuda.device_count())  # 使用可能なGPUの数
#print(torch.cuda.get_device_name(0))  # GPUの名前
accelerator='gpu' if torch.cuda.is_available() else 'cpu'
accelerator
devices=2 if torch.cuda.is_available() else 1
devices

True


2

In [11]:
import wandb
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mabababamb1[0m ([33mabababamb1-tokyo-university-of-science[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [9]:
%%time
#gpus = [1] if torch.cuda.is_available() else 0

wandb_logger = WandbLogger(log_model="all")

model = SimCLRModel(batch_size=10, transform=transform_simclr)
trainer = pl.Trainer(
    max_epochs=max_epochs, 
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    devices=[0,1,2] if torch.cuda.is_available() else 1,  # GPUが使える場合は[2]、使えない場合は1（CPUコア数）
    strategy="ddp_notebook",  # データ並列 (DataParallel)
    enable_progress_bar=True, # 進捗バーを有効化
    log_every_n_steps=100,  # ログの更新間隔を設定
    logger=wandb_logger,  # ログ機能を無効化
    use_distributed_sampler=False  # 分散サンプラーを無効化
)
trainer.fit(model, dataloader_train)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/3
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/3
Initializing distributed: GLOBAL_RANK: 2, MEMBER: 3/3
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 3 processes
----------------------------------------------------------------------------------------------------

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mabababamb1[0m ([33mabababamb1-tokyo-university-of-science[0m). Use [1m`wandb login --relogin`[0m to force relogi

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2]
LOCAL_RANK: 2 - CUDA_VISIBLE_DEVICES: [0,1,2]
LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1,2]

  | Name            | Type                 | Params | Mode 
-----------------------------------------------------------------
0 | backbone        | Sequential           | 11.2 M | train
1 | projection_head | SimCLRProjectionHead | 1.3 M  | train
2 | criterion       | NTXentLoss           | 0      | train
-----------------------------------------------------------------
12.5 M    Trainable params
0         Non-trainable params
12.5 M    Total params
49.941    Total estimated model params size (MB)
76        Modules in train mode
0         Modules in eval mode


Epoch 19: 100%|██████████| 6379/6379 [05:05<00:00, 20.85it/s, v_num=fb79]

`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|██████████| 6379/6379 [05:06<00:00, 20.81it/s, v_num=fb79]
CPU times: user 4min 1s, sys: 1min 47s, total: 5min 49s
Wall time: 1h 43min 37s


In [14]:
# モデルの状態を保存
torch.save(model.state_dict(), './simclr/1015.pth')

#### テストデータの埋め込み作成

In [40]:
def generate_embeddings(model, dataloader):
    """Generates representations for all images in the dataloader with
    the given model
    """

    embeddings = []
    filenames = []
    with torch.no_grad():
        #for img, label, fnames in dataloader:
        for img, label in dataloader:
            img = img.to(model.device)
            emb = model.backbone(img).flatten(start_dim=1)
            embeddings.append(emb)
            #filenames.extend(fnames)

    if embeddings:  # embeddingsが空でないことを確認
        embeddings = torch.cat(embeddings, 0)
        embeddings = normalize(embeddings)
        return embeddings, filenames
    else:
        raise RuntimeError("No embeddings generated. Please check your model and dataloader.")



model.eval()
embeddings, filenames = generate_embeddings(model, dataloader_test)
embeddings

RuntimeError: No embeddings generated. Please check your model and dataloader.

In [39]:
for batch in dataloader_test:
    print(batch)  # ここでバッチが正しく取得できているか確認