<a href="https://colab.research.google.com/github/RyosukeHanaoka/TechTeacher/blob/main/swin_with_kfold.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install vit_pytorch timm

from __future__ import print_function

import glob
import os
import random
import shutil
from pathlib import Path
from pprint import pprint

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from google.colab.patches import cv2_imshow
from PIL import Image
from PIL import UnidentifiedImageError
from sklearn.model_selection import StratifiedKFold
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import datasets, transforms
from tqdm.notebook import tqdm
from vit_pytorch.efficient import ViT

# Hyperparameters
batch_size = 64   # effective batch size; the loaders below use batch_size // 4 per step with gradient accumulation
epochs = 50       # training epochs per fold
lr = 0.3 * 1e-3   # Adam learning rate
gamma = 0.8       # StepLR multiplicative decay factor
seed = 42         # random seed


def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Integer seed applied to every random number generator.
    """
    random.seed(seed)
    # Only affects subprocesses started after this point; the running
    # interpreter's hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # no-op without CUDA
    torch.backends.cudnn.deterministic = True
    # Autotuning can select different (non-deterministic) kernels per run.
    torch.backends.cudnn.benchmark = False


seed_everything(seed)

# Compute device: prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Shared preprocessing constants: fixed input resolution and the standard
# ImageNet channel mean/std used for normalization.
_input_size = (224, 224)
_imagenet_mean = [0.485, 0.456, 0.406]
_imagenet_std = [0.229, 0.224, 0.225]

# Training pipeline: resize, random geometric/photometric augmentation,
# then tensor conversion and normalization.
train_transforms = transforms.Compose([
    transforms.Resize(_input_size),
    transforms.RandomRotation(degrees=15),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomPerspective(distortion_scale=0.2, p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(_imagenet_mean, _imagenet_std),
])

# Validation pipeline: deterministic resize + normalization only.
val_transforms = transforms.Compose([
    transforms.Resize(_input_size),
    transforms.ToTensor(),
    transforms.Normalize(_imagenet_mean, _imagenet_std),
])

# Image directories (Google Drive paths as mounted in Colab).
data_dir_RA = '/content/drive/MyDrive/crp>0.3/image_crp0.3_patient'
data_dir_nonRA = '/content/drive/MyDrive/images_ra_and_nonra/image_nonra_patient'

# Dataset definition
class CustomDataset(Dataset):
    """Binary image dataset built from two lists of image file paths.

    Images from ``ra_paths`` are labelled 1 and images from ``nonra_paths``
    are labelled 0. ``self.data`` and ``self.labels`` are index-aligned, so
    they can be passed directly to a stratified splitter.

    Args:
        ra_paths: List of image file paths for the positive (RA) class.
        nonra_paths: List of image file paths for the negative (non-RA) class.
        transform: Optional callable applied to each loaded PIL image.
    """

    def __init__(self, ra_paths, nonra_paths, transform=None):
        self.ra_paths = ra_paths
        self.nonra_paths = nonra_paths
        self.transform = transform
        self.data = self.ra_paths + self.nonra_paths
        self.labels = [1] * len(self.ra_paths) + [0] * len(self.nonra_paths)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        If Pillow cannot identify the file, a plain gray 224x224 placeholder
        image is used instead so a single corrupt file does not abort the
        epoch. The placeholder keeps the sample's true label; forcing
        class 0 would silently mislabel unreadable RA images.
        """
        img_path = self.data[idx]
        label = self.labels[idx]
        try:
            # The context manager closes the underlying file handle promptly.
            with Image.open(img_path) as img:
                image = img.convert('RGB')
        except UnidentifiedImageError:
            print(f"Couldn't read image at index {idx}: UnidentifiedImageError")
            image = Image.new('RGB', (224, 224), color='gray')
        if self.transform:
            image = self.transform(image)
        return image, label

# Training helpers (K-Fold aware). They are defined before the K-Fold loop
# below because a notebook cell executes top to bottom: calling the training
# function before its `def` has executed raises NameError.


def _train_one_epoch(model, loader, criterion, optimizer, scaler, device,
                     accumulation_steps, use_amp):
    """Run one training epoch with gradient accumulation; return (accuracy, loss).

    Accuracy is a 0-dim tensor on ``device``; loss is a Python float. Both are
    averaged over samples, so a smaller final batch is not over-weighted.
    """
    model.train()
    n_batches = len(loader)
    correct = torch.zeros((), device=device)
    loss_sum = 0.0
    n_samples = 0
    optimizer.zero_grad(set_to_none=True)

    for i, (data, label) in enumerate(tqdm(loader)):
        data = data.to(device)
        label = label.to(device)

        with torch.cuda.amp.autocast(enabled=use_amp):
            output = model(data)
            loss = criterion(output, label)

        # Divide so the gradients summed over one accumulation group match a
        # single large batch. The last group of an epoch may contain fewer
        # micro-batches and is then slightly down-weighted.
        scaler.scale(loss / accumulation_steps).backward()

        if (i + 1) % accumulation_steps == 0 or (i + 1) == n_batches:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

        batch_n = label.size(0)
        correct += (output.argmax(dim=1) == label).sum()
        loss_sum += loss.item() * batch_n
        n_samples += batch_n

        if (i + 1) % (accumulation_steps * 2) == 0:
            torch.cuda.empty_cache()

    n_samples = max(n_samples, 1)  # guard against an empty loader
    return correct / n_samples, loss_sum / n_samples


def _evaluate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradients; return (accuracy, loss).

    Accuracy is a 0-dim tensor on ``device``; loss is a Python float; both
    are averaged over samples.
    """
    model.eval()
    correct = torch.zeros((), device=device)
    loss_sum = 0.0
    n_samples = 0
    with torch.no_grad():
        for data, label in loader:
            data = data.to(device)
            label = label.to(device)
            output = model(data)

            batch_n = label.size(0)
            correct += (output.argmax(dim=1) == label).sum()
            loss_sum += criterion(output, label).item() * batch_n
            n_samples += batch_n

    torch.cuda.empty_cache()
    n_samples = max(n_samples, 1)  # guard against an empty loader
    return correct / n_samples, loss_sum / n_samples


def train_with_gradient_accumulation_kfold(model, train_loader, val_loader, criterion,
                                           optimizer, scheduler, device, epochs,
                                           accumulation_steps=4):
    """Train ``model`` for ``epochs`` epochs, validating after each epoch.

    Gradients are accumulated over ``accumulation_steps`` loader batches per
    optimizer step. Mixed precision (autocast + GradScaler) is enabled only
    when ``device`` is a CUDA device. ``scheduler.step()`` runs once per epoch.

    Args:
        model: Network to train (already moved to ``device``).
        train_loader: DataLoader yielding ``(images, labels)`` training batches.
        val_loader: DataLoader yielding ``(images, labels)`` validation batches.
        criterion: Loss function, e.g. ``nn.CrossEntropyLoss``.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: Learning-rate scheduler stepped once per epoch.
        device: ``'cuda'``/``'cpu'`` or a ``torch.device``.
        epochs: Number of epochs to run.
        accumulation_steps: Micro-batches accumulated per optimizer step.

    Returns:
        ``(train_acc_list, val_acc_list, train_loss_list, val_loss_list)``,
        one entry per epoch. Accuracies are 0-dim tensors on ``device``;
        losses are Python floats; both are sample-weighted epoch averages.
    """
    use_amp = torch.device(device).type == 'cuda'
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    train_acc_list, val_acc_list = [], []
    train_loss_list, val_loss_list = [], []

    for epoch in range(epochs):
        epoch_accuracy, epoch_loss = _train_one_epoch(
            model, train_loader, criterion, optimizer, scaler, device,
            accumulation_steps, use_amp,
        )
        epoch_val_accuracy, epoch_val_loss = _evaluate(model, val_loader, criterion, device)

        train_acc_list.append(epoch_accuracy)
        val_acc_list.append(epoch_val_accuracy)
        train_loss_list.append(epoch_loss)
        val_loss_list.append(epoch_val_loss)

        print(
            f"Epoch : {epoch+1} - "
            f"loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - "
            f"val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
        )

        scheduler.step()

    return train_acc_list, val_acc_list, train_loss_list, val_loss_list


# Collect RA and non-RA image paths. glob() returns entries in arbitrary
# filesystem order; sorting makes the seeded fold split reproducible.
ra_image_paths = sorted(glob.glob(os.path.join(data_dir_RA, '*')))
nonra_image_paths = sorted(glob.glob(os.path.join(data_dir_nonRA, '*')))

# Untransformed dataset: provides the paths and labels for the fold split.
full_dataset = CustomDataset(ra_image_paths, nonra_image_paths, transform=None)
labels = full_dataset.labels

# Subset only stores indices into its parent dataset, so train and validation
# need distinct parent instances. Setting .transform on one shared instance
# would make both subsets use whichever transform was assigned last.
train_base_dataset = CustomDataset(ra_image_paths, nonra_image_paths, transform=train_transforms)
val_base_dataset = CustomDataset(ra_image_paths, nonra_image_paths, transform=val_transforms)

# Micro-batches accumulated per optimizer step; each loader batch holds
# batch_size // accumulation_steps images, so the effective batch is batch_size.
accumulation_steps = 4

# Stratified K-Fold setup
kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=seed)

# Per-fold histories (one list of per-epoch values per fold)
all_train_acc = []
all_val_acc = []
all_train_loss = []
all_val_loss = []

# K-Fold cross-validation loop
for fold, (train_idx, val_idx) in enumerate(kf.split(full_dataset.data, labels)):
    print(f'Fold {fold + 1}')

    # Drop the previous fold's model, optimizer, scheduler and loaders before
    # building new ones so their GPU memory and persistent worker processes
    # can be released instead of coexisting with the new fold's.
    model = optimizer = scheduler = None
    train_loader = val_loader = None
    torch.cuda.empty_cache()

    train_dataset = Subset(train_base_dataset, train_idx)
    val_dataset = Subset(val_base_dataset, val_idx)

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size // accumulation_steps,
        shuffle=True,
        num_workers=2,
        pin_memory=True,
        persistent_workers=True,
        prefetch_factor=2,
    )

    val_loader = DataLoader(
        dataset=val_dataset,
        batch_size=batch_size // accumulation_steps,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
        persistent_workers=True,
    )

    # Fresh pretrained model per fold
    model = timm.create_model('swin_base_patch4_window7_224.ms_in1k', pretrained=True, num_classes=2)
    model.to(device)

    # Loss, optimizer and LR schedule
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = StepLR(optimizer, step_size=10, gamma=gamma)

    # Train and evaluate
    train_acc_list, val_acc_list, train_loss_list, val_loss_list = train_with_gradient_accumulation_kfold(
        model, train_loader, val_loader, criterion, optimizer, scheduler, device, epochs,
        accumulation_steps=accumulation_steps,
    )

    # Record this fold's histories
    all_train_acc.append(train_acc_list)
    all_val_acc.append(val_acc_list)
    all_train_loss.append(train_loss_list)
    all_val_loss.append(val_loss_list)
