In [1]:
import os
import time
import random
import argparse
import datetime
import numpy as np

import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist

from timm.utils import accuracy, AverageMeter
from torch.nn import CrossEntropyLoss
from tqdm import tqdm

from models import build_model
from models.swin_transformer import SwinTransformer

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.cuda.amp

import torchvision 
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchvision.utils import make_grid
from torchvision.transforms import InterpolationMode
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from timm.data import create_transform

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
# Batch size handed to get_loader for the pretrain, train and test loaders.
batch_size = 4
# Learning rate passed to the AdamW optimizer.
lr = 0.0001
# NOTE(review): the optimizer cell passes its own literal weight_decay rather
# than this value, and the two literals differ -- confirm which is intended.
weight_decay = 0.000001
# Total number of training steps (shown in the training progress bar).
t_total = 10000
# A checkpoint is written every `eval_every` training steps.
eval_every = 20

# Root folders for the pretrain / train / test image datasets.
# All three currently point at the same directory.
pt_path = "./img"
tr_path = "./img"
ts_path = "./img"

In [3]:
def build_transform():
    """Build the timm training-time augmentation pipeline for 224x224 inputs.

    Returns:
        The transform produced by ``timm.data.create_transform`` with
        ``is_training=True``: color jitter 0.4, the ``rand-m9-mstd0.5-inc1``
        auto-augment policy, bicubic interpolation, and random erasing
        (probability 0.25, ``pixel`` mode, one region).
    """
    # this should always dispatch to transforms_imagenet_train
    #
    # timm takes the interpolation mode as a string name ('bicubic',
    # 'bilinear', 'random', ...), not as a torchvision ``InterpolationMode``
    # member; the enum member does not compare equal to the string, so it is
    # not recognized as bicubic.
    return create_transform(
        input_size=224,
        is_training=True,
        color_jitter=0.4,
        auto_augment='rand-m9-mstd0.5-inc1',
        interpolation='bicubic',
        re_prob=0.25,
        re_mode='pixel',
        re_count=1,
    )
    
def get_loader(pt_path, tr_path, ts_path, pt_batch, tr_batch, ts_batch):
    """Create the pretrain, train and test ``DataLoader`` objects.

    Each ``*_path`` is read as a ``torchvision.datasets.ImageFolder`` root.
    The pretrain split uses the pipeline from ``build_transform()``, the train
    split uses a random-resized crop, and the test split uses a plain resize;
    train and test are both normalized with mean/std 0.5 per channel.

    Args:
        pt_path, tr_path, ts_path: Dataset roots for pretrain / train / test.
        pt_batch, tr_batch, ts_batch: Batch sizes for the matching loaders.

    Returns:
        Tuple ``(pretrain_train_loader, train_loader, test_loader)``. The
        first two shuffle; the test loader keeps the DataLoader default order.
    """
    def _tensor_and_normalize():
        # Fresh transform instances for every pipeline.
        return [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ]

    # TODO: needs to be changed to a center crop.
    train_crop = transforms.RandomResizedCrop((224, 224), scale=(0.05, 1.0))
    transform_train = transforms.Compose([train_crop] + _tensor_and_normalize())
    transform_test = transforms.Compose(
        [transforms.Resize((224, 224))] + _tensor_and_normalize()
    )
    pretrain_transform = build_transform()

    # Build all three datasets first, then wrap them in loaders.
    split_specs = (
        (pt_path, pretrain_transform),
        (tr_path, transform_train),
        (ts_path, transform_test),
    )
    pretrain_dataset, train_dataset, test_dataset = [
        torchvision.datasets.ImageFolder(root, transform=split_transform)
        for root, split_transform in split_specs
    ]

    # Options shared by every loader.
    shared_options = dict(num_workers=8, pin_memory=True)

    pretrain_train_loader = DataLoader(
        pretrain_dataset, batch_size=pt_batch, shuffle=True, **shared_options
    )
    train_loader = DataLoader(
        train_dataset, batch_size=tr_batch, shuffle=True, **shared_options
    )
    test_loader = DataLoader(test_dataset, batch_size=ts_batch, **shared_options)

    return pretrain_train_loader, train_loader, test_loader

def build_model():
    """Instantiate the Swin Transformer used in this notebook.

    The network takes 224x224 three-channel images and has a single output
    unit (``num_classes=1``).

    Note: defining this function rebinds the name ``build_model`` that the
    import cell pulls in from ``models``.

    Returns:
        A freshly constructed ``SwinTransformer``.
    """
    swin_config = dict(
        img_size=224,
        patch_size=2,
        in_chans=3,
        num_classes=1,
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        mlp_ratio=4,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        drop_path_rate=0.1,
        ape=False,
        patch_norm=True,
        use_checkpoint=False,
    )
    return SwinTransformer(**swin_config)

In [4]:
# Build the network and initialise it from the head-less pretrained checkpoint.
model = build_model()
# strict=False tolerates keys that are absent from / extra in the checkpoint;
# report them so a mismatched checkpoint does not go unnoticed.
load_result = model.load_state_dict(
    torch.load('pretrained_swin_wo_head.npz', map_location=torch.device('cpu')),
    strict=False,
)
print("Missing keys:", load_result.missing_keys)
print("Unexpected keys:", load_result.unexpected_keys)

# The training and validation code moves every batch to `device`, so the model
# has to live there as well.
model.to(device)

# NOTE(review): weight_decay here is a literal that differs from the
# `weight_decay` value in the configuration cell -- confirm which is intended.
optimizer = optim.AdamW(model.parameters(),  eps=1e-08, betas=(0.9, 0.999), lr=lr, weight_decay=0.00001)

# The pretraining loader is built but not used in this notebook. It is bound to
# a regular name because `_` is IPython's "last output" variable.
_pretrain_loader, train_loader, test_loader = get_loader(
    pt_path=pt_path,
    tr_path=tr_path,
    ts_path=ts_path,
    pt_batch=batch_size,
    tr_batch=batch_size,
    ts_batch=batch_size,
)


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [5]:
def simple_accuracy(preds, labels):
    """Return the fraction of positions where ``preds`` equals ``labels``.

    Both arguments must support element-wise ``==`` yielding an object with a
    ``mean()`` method (e.g. NumPy arrays of equal shape).
    """
    matches = preds == labels
    return matches.mean()

def valid(model, test_loader):
    """Evaluate a single-logit binary classifier on ``test_loader``.

    The model is switched to eval mode and run without gradient tracking.
    Each batch is moved to the notebook-level ``device``; the loss is binary
    cross-entropy with logits and a sample is predicted as class 1 when its
    logit is positive (i.e. sigmoid(logit) > 0.5).

    Args:
        model: Module whose output has shape ``(batch, 1)`` (it is squeezed on
            dimension 1).
        test_loader: Iterable yielding ``(images, labels)`` tensor pairs.

    Returns:
        Accuracy as a fraction in ``[0, 1]``.

    Raises:
        ValueError: If ``test_loader`` yields no batches.
    """
    eval_losses = AverageMeter()
    loss_fct = F.binary_cross_entropy_with_logits

    model.eval()
    batch_preds, batch_labels = [], []
    epoch_iterator = tqdm(test_loader,
                          desc="Validating... (loss=X.X)",
                          bar_format="{l_bar}{r_bar}",
                          dynamic_ncols=True,)

    with torch.no_grad():
        # Iterate over the tqdm wrapper so the progress bar actually advances.
        for x, y in epoch_iterator:
            x = x.to(device)
            y = y.to(device).view(-1)

            logits = model(x).squeeze(1)
            # .float() keeps the target on the same device as the logits.
            eval_loss = loss_fct(logits, y.float())
            eval_losses.update(eval_loss.item())

            # One logit per sample: threshold at 0 rather than argmax, which
            # would collapse the whole batch to a single index.
            preds = (logits > 0).long()

            batch_preds.append(preds.cpu().numpy())
            batch_labels.append(y.cpu().numpy())
            epoch_iterator.set_description("Validating... (loss=%2.5f)" % eval_losses.val)

    if not batch_preds:
        raise ValueError("test_loader produced no batches; nothing to validate")

    # Concatenate once instead of growing an array batch by batch.
    all_preds = np.concatenate(batch_preds, axis=0)
    all_label = np.concatenate(batch_labels, axis=0)

    acc = simple_accuracy(all_preds, all_label)

    print("\n")
    print("Valid Loss: %2.5f" % eval_losses.avg)
    print("Valid Accuracy: %2.5f" % (acc*100))
    return acc

In [6]:
# Fine-tuning loop: binary cross-entropy on the single output logit, with
# mixed precision when running on CUDA, gradient-norm clipping at 1.0, and a
# checkpoint every `eval_every` steps. Training stops after `t_total` steps.
loss_fct = F.binary_cross_entropy_with_logits
losses = AverageMeter()
global_step = 0

# Mixed precision only applies on CUDA; disable it explicitly elsewhere.
use_amp = device.type == "cuda"
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

# Batches are moved to `device` below, so make sure the model is there too.
model.to(device)

while global_step < t_total:
    model.train()
    epoch_iterator = tqdm(train_loader,
                          desc="Training (X / X Steps) (loss=X.X)",
                          bar_format="{l_bar}{r_bar}",
                          dynamic_ncols=True
                         )

    for step, batch in enumerate(epoch_iterator):
        x, y = (t.to(device) for t in batch)
        y = y.view(-1)
        optimizer.zero_grad()

        with torch.cuda.amp.autocast(enabled=use_amp):
            logit = model(x).squeeze(1)
            # .float() keeps the target on the same device as the logits.
            loss = loss_fct(logit, y.float())
        losses.update(loss.item())

        scaler.scale(loss).backward()
        # Unscale before clipping so the threshold applies to true gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        scaler.step(optimizer)
        scaler.update()

        global_step += 1
        epoch_iterator.set_description(
                    "Training (%d / %d Steps) (loss=%2.5f)" % (global_step, t_total, losses.val)
                )

        if global_step % eval_every == 0:
            torch.save(model.state_dict(),'fine_tuned_swin.npz')
            # Validation is currently disabled.
            #acc = valid(model, train_loader)

        # Stop in the middle of an epoch once the step budget is used up.
        if global_step >= t_total:
            break
            #acc = valid(model, train_loader)

Training (1 / 10000 Steps) (loss=0.83463):  20%|| 1/5 [00:12<00:50, 12.68s/it]


KeyboardInterrupt: 

In [9]:
def count_param(model):
    """Return the total number of elements across ``model``'s trainable parameters.

    Only parameters with ``requires_grad`` set are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
# Display the number of trainable parameters in the notebook's model.
count_param(model)
