In [1]:
import os
import gc
import glob
import sys
import time
import random
import logging
import argparse
from datetime import datetime
from pathlib import Path

import cv2
import numpy as np
import pandas as pd
import tqdm

from sklearn.model_selection import train_test_split, KFold

import timm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, Subset
from torchvision import datasets, transforms
from PIL import Image

logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)

SEED = 111

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark     = False


class ImageDataset(Dataset):
    def __init__(self, data_df, transform=None):

        self.data_df = data_df
        self.transform = transform

    def __getitem__(self, idx):
        # достаем имя изображения и ее лейбл
        image_name, label = self.data_df.iloc[idx]['ID_img'], self.data_df.iloc[idx]['class']

        # читаем картинку. read the image
        image = cv2.imread(f"../train/{image_name}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = Image.fromarray(image)
        
        # преобразуем, если нужно. transform it, if necessary
        if self.transform:
            image = self.transform(image)
        
        return image, torch.tensor(label).long()
    
    def __len__(self):
        return len(self.data_df)

transform_train = transforms.Compose([
    transforms.RandomAffine(15,  translate=(0.15, 0.15)),
    transforms.RandomResizedCrop((384, 384), scale=(0.70, 1.3)),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
    transforms.RandomVerticalFlip(0.05),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                          std=[0.229, 0.224, 0.225]),
])

transform_valid = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                          std=[0.229, 0.224, 0.225]),
])

data = []
for cl in [0, 1, 2]:
    for image_name in glob.glob(f"../2906/{cl}/*.jpg"):
        data.append({"ID_img":Path(image_name).name, "class":cl})
    for image_name in glob.glob(f"../2906/{cl}/*.jpeg"):
        data.append({"ID_img":Path(image_name).name, "class":cl})
data_df = pd.DataFrame(data)

full_dataset_train = ImageDataset(data_df, transform_train)
full_dataset_val   = ImageDataset(data_df, transform_valid)

In [2]:
fold_idx = []
for rnd in [SEED]:
    cv = KFold(n_splits=7, random_state=rnd, shuffle=True)
    fold_idx.append(list(cv.split(data_df)))

In [3]:
def train(model:nn.Module, criterion, optimizer, train_dataloader, test_dataloader, epochs:int=10, is_init:bool=False, it=0):
    train_loss_log = []
    val_loss_log = []
    
    train_acc_log = []
    val_acc_log = []
    
    best_acc = 0
    
    for epoch in range(epochs):
        model.train()
        train_loss = 0.
        train_size = 0
        
        train_pred = 0.

        for imgs, labels in train_dataloader:
            optimizer.zero_grad()

            imgs = imgs.cuda()
            labels = labels.cuda()

            y_pred = model(imgs)

            loss = criterion(y_pred, labels)
            loss.backward()
            
            train_loss += loss.item()
            train_size += y_pred.size(0)
            train_loss_log.append(loss.data / y_pred.size(0))
            
            train_pred += (y_pred.argmax(1) == labels).sum()

            optimizer.step()

        if is_init:
            continue

        train_acc_log.append(train_pred / train_size)

        val_loss = 0.
        val_size = 0
        
        val_pred = 0.
        
        model.eval()
        
        with torch.no_grad():
            for imgs, labels in test_dataloader:
                
                imgs = imgs.cuda()
                labels = labels.cuda()
                
                pred = model(imgs)
                loss = criterion(pred, labels)
                
                val_loss += loss.item()
                val_size += pred.size(0)
                
                val_pred += (pred.argmax(1) == labels).sum()
        

        val_loss_log.append(val_loss / val_size)
        val_acc_log.append(val_pred / val_size)
        
        if best_acc <= val_pred / val_size:
            best_acc = val_pred / val_size
            
            torch.save(model.state_dict(), f"checkpoints/{it:03d}_2906.pth")

    if not is_init:
        return best_acc

In [4]:
start_idx = 0

for i, fold in enumerate(fold_idx[start_idx:], start=start_idx):
    print(f'start iter {i}')
    val_preds = np.zeros((len(data_df), 3))
    train_preds = np.zeros((len(data_df), 3))
    for y, (train_idx, val_idx) in enumerate(fold):
        print(f'start fold {y}')
        start = time.time()
        torch.cuda.empty_cache()
        model = timm.models.swin_large_patch4_window12_384(pretrained=True)
        model.head = nn.Sequential(nn.Linear(model.head.weight.shape[1], 4096),
                        nn.Dropout(0.20),
                        nn.ReLU(),
                        nn.Linear(4096, 3))
        nn.init.xavier_normal_(model.head[0].weight)
        nn.init.zeros_(model.head[0].bias)
        nn.init.xavier_normal_(model.head[-1].weight)
        nn.init.zeros_(model.head[-1].bias)
        model = model.cuda()
        criterion = torch.nn.CrossEntropyLoss()
        dataset_train = Subset(full_dataset_train, train_idx)
        dataset_val   = Subset(full_dataset_val, val_idx)
        train_loader = torch.utils.data.DataLoader(dataset=dataset_train,
                                           batch_size=6,
                                           shuffle=True,
                                           pin_memory=True,
                                           num_workers=6)

        valid_loader = torch.utils.data.DataLoader(dataset=dataset_val,
                                           batch_size=4,
                                           shuffle=False,
                                           pin_memory=True,
                                           num_workers=4)
        optimizer = torch.optim.AdamW(model.head.parameters(), lr=1e-5)

        train(model, criterion, optimizer, train_loader, valid_loader, epochs=1, is_init=True)

        optimizer = torch.optim.AdamW(model.head.parameters(), lr=1e-4)
        best_acc = train(model, criterion, optimizer, train_loader, valid_loader, epochs=40, it=i * 100 + y)

        model.cpu()
        del model
        torch.cuda.empty_cache()
        gc.collect()
        print(f"total {time.time() - start:.1f} sec, best acc {best_acc:.3f}")
        

start iter 0
start fold 0


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
INFO:timm.models.helpers:Loading pretrained weights from url (https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth)


total 2650.9 sec, best acc 0.974
start fold 1


INFO:timm.models.helpers:Loading pretrained weights from url (https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth)


total 2674.6 sec, best acc 0.949
start fold 2


INFO:timm.models.helpers:Loading pretrained weights from url (https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth)


total 2668.2 sec, best acc 0.949
start fold 3


INFO:timm.models.helpers:Loading pretrained weights from url (https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth)


total 2657.8 sec, best acc 0.962
start fold 4


INFO:timm.models.helpers:Loading pretrained weights from url (https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth)


total 2674.0 sec, best acc 0.961
start fold 5


INFO:timm.models.helpers:Loading pretrained weights from url (https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth)


total 2664.5 sec, best acc 0.974
start fold 6


INFO:timm.models.helpers:Loading pretrained weights from url (https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth)


total 2666.4 sec, best acc 0.961
