In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import glob
import os
import numpy as np
from PIL import Image
import math
import sys
import random
from byol_pytorch import BYOL
from torchvision import models
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
config = {
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'gpu_id': 2,  # CUDA device index selected when 'device' == 'cuda'
    'seed': 0,    # seed for the python / numpy / torch RNGs
    # NOTE(review): absolute, machine-specific paths -- adjust for your environment.
    'train_pth': '/data/dlcv/hw4/mini/train/',
    'csv_pth': '/data/dlcv/hw4/mini/train.csv',
    'best_pth': '/data/allen/hw4model/longep/backbone2.pth',        # backbone with the lowest mean epoch loss
    'last_pth': '/data/allen/hw4model/longep/backbone2_last.pth',   # backbone after the most recent epoch
    'last_byol_pth': '/data/allen/hw4model/longep/byol2_last.pth',  # BYOL + optimizer + scheduler state, used to resume
    'bsz': 512,
    'lr': 1.e-3,
    'epochs': 3000,
    'imgsz': 128,
}

# Images are only resized and converted to tensors here; no normalization is applied.
backbone_transform = transforms.Compose([
    transforms.Resize((config['imgsz'], config['imgsz'])),
    transforms.ToTensor(),
])

if config["device"] == "cuda":
    torch.cuda.set_device(config['gpu_id'])

# Seed every RNG library the notebook imports so shuffling and weight init are
# repeatable (GPU kernels may still introduce nondeterminism).
random.seed(config['seed'])
np.random.seed(config['seed'])
torch.manual_seed(config['seed'])

print('Device used :', config['device'])

Device used : cuda


In [3]:
def save_checkpoint(checkpoint_path, model, optimizer, scheduler, ep, best_loss):
    """Save a full, resumable training checkpoint.

    The state is written to a temporary sibling file and then moved onto
    ``checkpoint_path`` with ``os.replace``. An interrupted save (e.g. a
    KeyboardInterrupt mid-write) therefore never leaves a truncated file in
    place of the previous good checkpoint. This matters because the training
    cells overwrite this file every epoch.

    Args:
        checkpoint_path: destination file path (str or path-like).
        model: module whose ``state_dict()`` is stored under 'model_state_dict'.
        optimizer: optimizer whose state is stored under 'optimizer_state_dict'.
        scheduler: LR scheduler whose state is stored under 'scheduler_state_dict'.
        ep: index of the epoch that just finished, stored as 'last_ep'.
        best_loss: best loss seen so far, stored as 'best_loss'.
    """
    state = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'last_ep': ep,
        'best_loss': best_loss,
    }
    # Same directory as the target, so os.replace is a simple rename.
    tmp_path = os.fspath(checkpoint_path) + '.tmp'
    torch.save(state, tmp_path)
    os.replace(tmp_path, checkpoint_path)
    print('checkpoint saved to {}'.format(checkpoint_path))

def save_model_only(checkpoint_path, model):
    """Save only ``model``'s weights, under the key 'model_state_dict'.

    The file is written to a temporary sibling and then moved into place with
    ``os.replace``. An interrupted save therefore never leaves a truncated file
    at ``checkpoint_path``, which the training cells overwrite every epoch.

    Args:
        checkpoint_path: destination file path (str or path-like).
        model: module whose ``state_dict()`` is saved.
    """
    state = {'model_state_dict': model.state_dict()}
    tmp_path = os.fspath(checkpoint_path) + '.tmp'
    torch.save(state, tmp_path)
    os.replace(tmp_path, checkpoint_path)
    print('model saved to {}'.format(checkpoint_path))

def load_checkpoint(checkpoint_path, device='cpu'):
    """Load a checkpoint written by ``save_checkpoint``.

    Args:
        checkpoint_path: path of the checkpoint file.
        device: passed to ``torch.load`` as ``map_location`` so tensors land there.

    Returns:
        Tuple ``(model_state_dict, optimizer_state_dict, scheduler_state_dict,
        last_ep, best_loss)``.

    Raises:
        KeyError: if any of those entries is missing from the file.

    Note: ``torch.load`` unpickles the file -- only load trusted checkpoints.
    """
    ckpt = torch.load(checkpoint_path, map_location=device)
    wanted = ("model_state_dict", "optimizer_state_dict", "scheduler_state_dict", 'last_ep', 'best_loss')
    return tuple(ckpt[key] for key in wanted)

def load_model_only(checkpoint_path, device='cpu'):
    """Return the 'model_state_dict' entry of a saved checkpoint.

    Works on files from either ``save_model_only`` or ``save_checkpoint``,
    since both store the weights under that key. ``device`` is forwarded to
    ``torch.load`` as ``map_location``. ``torch.load`` unpickles the file, so
    only load trusted checkpoints.
    """
    return torch.load(checkpoint_path, map_location=device)["model_state_dict"]

In [4]:
class DS(Dataset):
    """Unlabeled image dataset used for BYOL pre-training.

    Items come either from the 'filename' column of a CSV (each name joined
    onto ``datapath``) or, when ``csvpath`` is None, from every entry matched
    by ``datapath/*`` in sorted order. ``__getitem__`` returns only the
    (optionally transformed) image.

    Args:
        datapath: directory containing the images.
        csvpath: optional CSV with a 'filename' column; None lists ``datapath`` instead.
        transform: optional callable applied to each RGB PIL image.

    Raises:
        FileNotFoundError: if ``csvpath`` is given but missing, or if
            ``csvpath`` is None and ``datapath`` is missing.
    """

    def __init__(self, datapath, csvpath, transform=None) -> None:
        self.transform = transform
        self.data = []  # list of (image path, image file name)
        if csvpath is not None:
            # Raise instead of exit(): exit() is an interactive-shell helper that
            # aborts the process/kernel rather than surfacing a catchable error.
            if not os.path.exists(csvpath):
                raise FileNotFoundError(f"Can't find {csvpath}")
            df = pd.read_csv(csvpath)
            self.data = [(os.path.join(datapath, name), name) for name in df['filename']]
        else:
            if not os.path.exists(datapath):
                raise FileNotFoundError(f"Can't open {datapath}")
            # glob order is filesystem-dependent; sort for a reproducible index order.
            for path in sorted(glob.glob(os.path.join(datapath, "*"))):
                self.data.append((path, os.path.basename(path)))
        self.len = len(self.data)
        print(self.len)

    def __getitem__(self, index):
        imgpath, _ = self.data[index]
        # The context manager closes the underlying file. convert('RGB') returns a
        # loaded copy with exactly 3 channels, even for grayscale/RGBA/palette
        # files, matching the 3-channel input the ResNet backbone expects.
        with Image.open(imgpath) as src:
            img = src.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img

    def __len__(self):
        return self.len

In [5]:
# Training images listed in the CSV, resized/tensorized by backbone_transform.
train_set = DS(datapath=config['train_pth'], csvpath=config['csv_pth'], transform=backbone_transform)
train_loader = DataLoader(
    train_set,
    shuffle=True,
    batch_size=config['bsz'],
    pin_memory=True,
    num_workers=4,
    # Keep the worker processes alive across epochs instead of re-spawning
    # them for each of the config['epochs'] passes over the data.
    persistent_workers=True,
    # Drop a trailing partial batch. A size-1 batch would fail in BatchNorm
    # layers in train mode (assumed present in BYOL's projector MLP -- confirm).
    drop_last=True,
)

38400


In [6]:
# Backbone: a ResNet-50 with randomly initialised weights (weights=None).
resnet = models.resnet50(weights=None)
resnet = resnet.to(config['device'])

# BYOL wrapper around the backbone. Representations are taken from the
# 'avgpool' layer; moving_average_decay=0.9995 is the decay for the
# moving-average update of the target network.
learner = BYOL(resnet, image_size=config['imgsz'], hidden_layer='avgpool', moving_average_decay=0.9995)

# Adam with light weight decay, plus a cosine learning-rate schedule whose
# period is config['epochs'] scheduler steps.
opt = torch.optim.Adam(learner.parameters(), lr=config['lr'], weight_decay=1.5e-6)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=config['epochs'])

In [7]:
# Fresh run: track the lowest mean epoch loss seen so far.
best_loss = torch.inf
learner.train()
for ep in range(config['epochs']):
    running_loss = 0.
    n_batches = 0
    for imgs in train_loader:
        # train_loader pins host memory, so the copy can be issued non-blocking.
        imgs = imgs.to(config['device'], non_blocking=True)
        loss = learner(imgs)
        opt.zero_grad()
        loss.backward()
        opt.step()
        # Moving-average update of BYOL's target network after each optimizer step.
        learner.update_moving_average()
        running_loss += loss.item()
        n_batches += 1
    # Count batches explicitly so an empty loader fails with a clear message
    # rather than a NameError / ZeroDivisionError.
    if n_batches == 0:
        raise RuntimeError('train_loader yielded no batches; check the dataset path and batch size')
    scheduler.step()  # cosine schedule advances once per epoch
    train_loss = running_loss / n_batches
    # resnet is the backbone handed to BYOL, so these saves capture the trained encoder.
    if train_loss < best_loss:
        best_loss = train_loss
        save_model_only(config['best_pth'], resnet)
    # Always keep the latest backbone plus a full resumable checkpoint.
    save_model_only(config['last_pth'], resnet)
    save_checkpoint(config['last_byol_pth'], learner, opt, scheduler, ep, best_loss)
    print(f"Epoch [{ep+1}/{config['epochs']}] loss : {train_loss}")

model saved to /data/allen/hw4model/longep/backbone2.pth
model saved to /data/allen/hw4model/longep/backbone2_last.pth
checkpoint saved to /data/allen/hw4model/longep/byol2_last.pth
Epoch [1/3000] loss : 1.5127683560053506
model saved to /data/allen/hw4model/longep/backbone2.pth
model saved to /data/allen/hw4model/longep/backbone2_last.pth
checkpoint saved to /data/allen/hw4model/longep/byol2_last.pth
Epoch [2/3000] loss : 1.316165505250295
model saved to /data/allen/hw4model/longep/backbone2.pth
model saved to /data/allen/hw4model/longep/backbone2_last.pth
checkpoint saved to /data/allen/hw4model/longep/byol2_last.pth
Epoch [3/3000] loss : 1.2674558925628663
model saved to /data/allen/hw4model/longep/backbone2.pth
model saved to /data/allen/hw4model/longep/backbone2_last.pth
checkpoint saved to /data/allen/hw4model/longep/byol2_last.pth
Epoch [4/3000] loss : 1.1886670784155529
model saved to /data/allen/hw4model/longep/backbone2.pth
model saved to /data/allen/hw4model/longep/backbone2

## Resume training from the last BYOL checkpoint

In [6]:
# Rebuild the model/optimizer/scheduler with the same hyperparameters used when
# the checkpoint was written, then restore their saved states.
resnet = models.resnet50(weights=None).to(config['device'])
model_state, opt_state, sch_state, last_ep, best_loss = load_checkpoint(config['last_byol_pth'], config['device'])
learner = BYOL(
    resnet,
    image_size = config['imgsz'],
    hidden_layer = 'avgpool',
    moving_average_decay = 0.9995
)
learner.load_state_dict(model_state)
opt = torch.optim.Adam(learner.parameters(), lr=config['lr'], weight_decay=1.5e-6)
opt.load_state_dict(opt_state)
# The scheduler is built after the optimizer state is restored, then given its own saved state.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=config['epochs'])
scheduler.load_state_dict(sch_state)
learner.train()
# last_ep is the index of the last epoch that finished and was checkpointed; continue from the next one.
for ep in range(last_ep + 1, config['epochs']):
    running_loss = 0.
    n_batches = 0
    for imgs in train_loader:
        # train_loader pins host memory, so the copy can be issued non-blocking.
        imgs = imgs.to(config['device'], non_blocking=True)
        loss = learner(imgs)
        opt.zero_grad()
        loss.backward()
        opt.step()
        # Moving-average update of BYOL's target network after each optimizer step.
        learner.update_moving_average()
        running_loss += loss.item()
        n_batches += 1
    # Count batches explicitly so an empty loader fails with a clear message
    # rather than a NameError / ZeroDivisionError.
    if n_batches == 0:
        raise RuntimeError('train_loader yielded no batches; check the dataset path and batch size')
    scheduler.step()  # cosine schedule advances once per epoch
    train_loss = running_loss / n_batches
    if train_loss < best_loss:
        best_loss = train_loss
        save_model_only(config['best_pth'], resnet)
    # Always keep the latest backbone plus a full resumable checkpoint.
    save_model_only(config['last_pth'], resnet)
    save_checkpoint(config['last_byol_pth'], learner, opt, scheduler, ep, best_loss)
    print(f"Epoch [{ep+1}/{config['epochs']}] loss : {train_loss}")

model saved to /data/allen/hw4model/longep/backbone2_last.pth
checkpoint saved to /data/allen/hw4model/longep/byol2_last.pth
Epoch [1/3000] loss : 0.20300449639558793
model saved to /data/allen/hw4model/longep/backbone2_last.pth
checkpoint saved to /data/allen/hw4model/longep/byol2_last.pth
Epoch [2/3000] loss : 0.18621910278995832
model saved to /data/allen/hw4model/longep/backbone2_last.pth
checkpoint saved to /data/allen/hw4model/longep/byol2_last.pth
Epoch [3/3000] loss : 0.21636226378381251
model saved to /data/allen/hw4model/longep/backbone2_last.pth
checkpoint saved to /data/allen/hw4model/longep/byol2_last.pth
Epoch [4/3000] loss : 0.22580092589060466
model saved to /data/allen/hw4model/longep/backbone2_last.pth
checkpoint saved to /data/allen/hw4model/longep/byol2_last.pth
Epoch [5/3000] loss : 0.22535799977680046
model saved to /data/allen/hw4model/longep/backbone2_last.pth
checkpoint saved to /data/allen/hw4model/longep/byol2_last.pth
Epoch [6/3000] loss : 0.2169123171021540