In [None]:
import torch
import time
import wandb
import os

from torchvision import transforms

from resnet import resnet18
from utils import AverageMeter, accuracy, add_to_confusion_matrix, get_per_class_results, make_deterministic, save_ckpt, load_ckpt

# ---------------------------------------------------------------------------
# Experiment configuration: every knob a reader is likely to tune lives here.
# ---------------------------------------------------------------------------

# Run identity and experiment tracking.
run_name = "resnet18_places10_LT_undersampling"
turn_on_wandb = True

# Model / hardware selection. `use_pretrained` decides whether the network is
# built from pretrained weights or from scratch in the model-building cell.
use_pretrained = True
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Dataset locations, resolved against the directory the notebook is run from.
ROOT_DIR = os.getcwd()
train_dataset_path = os.path.join(ROOT_DIR, "posco_data/places10_LT/train")
valid_dataset_path = os.path.join(ROOT_DIR, "posco_data/places10_LT/valid")

# Optimisation schedule.
batch_size = 64
total_epochs = 30
lr_steps = [10, 20, 25]  # epochs at which the learning-rate scheduler steps down

# Project helper from `utils`, called with a fixed seed; it runs before any
# model or data loader is created.
make_deterministic(random_seed=42)

In [None]:
import random
from randaugment import rand_augment_transform, GaussianBlur_simclr
from torchvision.datasets import ImageFolder

class CustomDataset(torch.utils.data.Dataset):
    """ImageFolder-backed dataset with optional class-balanced undersampling.

    Args:
        dataset_path: Root directory handed to ``ImageFolder``.
        transform: Transform applied to every image when ``use_randaug`` is
            False. When ``use_randaug`` is True its last entry
            (``transform.transforms[-1]``) is reused as the final step of both
            augmentation pipelines, so it must then be a ``Compose`` whose last
            element is the normalisation step.
        use_randaug: If True, each sample is passed through one of two
            augmentation pipelines (``aug1`` / ``aug2``) picked with equal
            probability instead of ``transform``.
        under_sample: If True (default, the original behaviour) every class is
            truncated to the size of the smallest class. Pass False to keep
            all samples, e.g. for an evaluation split.
    """

    def __init__(self, dataset_path, transform, use_randaug=False, under_sample=True):
        super().__init__()
        self.dataset_path = dataset_path
        self.transform = transform
        self.use_randaug = use_randaug
        self.under_sample = under_sample
        if self.use_randaug:
            rgb_mean = (0.485, 0.456, 0.406)
            # Fill colour / translation budget passed to the RandAugment factory.
            ra_params = dict(
                translate_const=int(224 * 0.45),
                img_mean=tuple(min(255, round(255 * x)) for x in rgb_mean),
            )
            normalize = self.transform.transforms[-1]  # get normalize layer
            self.aug1 = transforms.Compose([
                transforms.RandomResizedCrop(224, scale=(0.08, 1.)),
                transforms.RandomHorizontalFlip(),
                transforms.RandomApply([
                    transforms.ColorJitter(0.4, 0.4, 0.4, 0.0)
                ], p=1.0),
                rand_augment_transform('rand-n{}-m{}-mstd0.5'.format(2, 10), ra_params),
                transforms.ToTensor(),
                normalize
            ])
            self.aug2 = transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomApply([
                    transforms.ColorJitter(0.4, 0.4, 0.4, 0.0)  # not strengthened
                ], p=1.0),
                transforms.RandomGrayscale(p=0.2),
                transforms.RandomApply([GaussianBlur_simclr([.1, 2.])], p=0.5),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize
            ])
        print("Loading dataset...")
        self.load_dataset()

    def do_under_sampling(self, data):
        """Return sample indices that give every class the same number of items.

        Args:
            data: Object exposing ``classes`` (one entry per class) and
                ``targets`` (one integer class id per sample).

        Returns:
            A list of indices grouped by ascending class id. Within a class the
            first ``min_count`` samples (in dataset order) are kept, where
            ``min_count`` is the size of the smallest class. Returns ``[]``
            when ``data.classes`` is empty.
        """
        # Single pass over the targets instead of one full scan per class.
        buckets = {class_id: [] for class_id in range(len(data.classes))}
        for sample_idx, target in enumerate(data.targets):
            bucket = buckets.get(target)
            if bucket is not None:  # targets outside the class range are ignored
                bucket.append(sample_idx)

        min_count = min((len(bucket) for bucket in buckets.values()), default=0)

        indices = []
        for class_id in range(len(data.classes)):
            indices += buckets[class_id][:min_count]
        return indices

    def load_dataset(self):
        """Read the image folder and, if requested, wrap it in a balanced subset."""
        self.data = ImageFolder(self.dataset_path)

        if self.under_sample:
            self.data = torch.utils.data.Subset(self.data, self.do_under_sampling(self.data))

    def do_transform(self, img):
        """Apply ``transform``, or one of ``aug1`` / ``aug2`` chosen at random."""
        if not self.use_randaug:
            return self.transform(img)
        # One uniform draw per sample: below 0.5 -> aug1, otherwise aug2.
        augment = self.aug1 if random.random() < 0.5 else self.aug2
        return augment(img)

    def __getitem__(self, index):
        """Return ``(transformed_image, label)`` for the sample at ``index``."""
        img, label = self.data[index]
        return self.do_transform(img), label

    def __len__(self):
        """Number of samples after optional undersampling."""
        return len(self.data)

In [None]:
# Build Model
if use_pretrained:
    # Load the pretrained 1000-class network, then swap its head for a 10-way classifier.
    model = resnet18(pretrained=True, num_classes=1000)
    model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=10)
else:
    model = resnet18(num_classes=10)
model = model.to(device)


# Make Dataset & DataLoader
# Both pipelines end with the same per-channel normalisation.
norm_stats = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0),
    transforms.ToTensor(),
    transforms.Normalize(**norm_stats),
])
valid_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(**norm_stats),
])

# NOTE(review): CustomDataset.load_dataset applies class undersampling, so the
# validation split is reduced as well — confirm this is intended.
train_dataset = CustomDataset(dataset_path=train_dataset_path, transform=train_transforms, use_randaug=False)
valid_dataset = CustomDataset(dataset_path=valid_dataset_path, transform=valid_transforms, use_randaug=False)

# Settings shared by both loaders; only the shuffling differs.
loader_kwargs = dict(batch_size=batch_size, num_workers=8, pin_memory=True)
train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_kwargs)
valid_loader = torch.utils.data.DataLoader(valid_dataset, shuffle=False, **loader_kwargs)
print("Train dataset length:", len(train_dataset))
print("Valid dataset length:", len(valid_dataset))

# Make Optimizer & Loss
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9, weight_decay=2e-4)
# Learning rate is multiplied by 0.1 at each epoch listed in lr_steps.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=lr_steps, gamma=0.1)
criterion = torch.nn.CrossEntropyLoss().to(device)

In [None]:
def train(device, train_loader, model, criterion, optimizer, scheduler, epoch, turn_on_wandb=False):
    """Run one training epoch, step the LR scheduler, and report the averages.

    Args:
        device: Device the batches are moved to (also passed to ``AverageMeter``).
        train_loader: Iterable yielding ``(images, labels)`` batches.
        model: Network to optimise; switched to training mode here.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        scheduler: LR scheduler stepped once after the last batch.
        epoch: Zero-based epoch index; reported and logged as ``epoch + 1``.
        turn_on_wandb: When True, the epoch averages are sent to ``wandb.log``.
    """
    epoch_began = time.time()

    loss_meter = AverageMeter(device)
    top1_meter = AverageMeter(device)
    top5_meter = AverageMeter(device)

    model.train()
    for batch_images, batch_labels in train_loader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        logits = model(batch_images)
        batch_loss = criterion(logits, batch_labels)

        # Standard update: clear stale gradients, backpropagate, apply the step.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        # Meters are weighted by batch size so the epoch average is per-sample.
        acc1, acc5 = accuracy(logits, batch_labels, topk=(1, 5))
        n_samples = batch_images.size(0)
        loss_meter.update(batch_loss.item(), n_samples)
        top1_meter.update(acc1.item(), n_samples)
        top5_meter.update(acc5.item(), n_samples)

    # Timing stops before the scheduler step, so it covers the batch loop only.
    epoch_ended = time.time()
    scheduler.step()

    elapsed = time.strftime('%H:%M:%S', time.gmtime(epoch_ended - epoch_began))
    print(f"==================== Train Summary: Epoch {epoch+1} ====================", flush=True)
    print(f"Train Epoch Elapsed time: {elapsed}", flush=True)
    print(f"Loss: {loss_meter.avg:.2f}\t Prec@1: {top1_meter.avg:.2f}\t Prec@5: {top5_meter.avg:.2f}", flush=True)
    if turn_on_wandb:
        epoch_metrics = {
            "train/loss": loss_meter.avg,
            "train/top1": top1_meter.avg,
            "train/top5": top5_meter.avg,
        }
        wandb.log(epoch_metrics, step=epoch+1)


def validate(device, valid_loader, model, criterion, epoch, turn_on_wandb=False, num_classes=10):
    """Evaluate ``model`` on ``valid_loader`` without gradient tracking.

    Args:
        device: Device the batches and the confusion matrix live on.
        valid_loader: Iterable yielding ``(images, labels)`` batches.
        model: Network to evaluate; switched to eval mode here.
        criterion: Loss callable taking ``(outputs, labels)``.
        epoch: Zero-based epoch index; reported and logged as ``epoch + 1``.
        turn_on_wandb: When True, the averages are sent to ``wandb.log``.
        num_classes: Side length of the square confusion matrix. Defaults to
            10, the size that was previously hard-coded.

    Returns:
        ``(top1_average, per_class_results)`` where ``per_class_results`` is
        whatever ``get_per_class_results`` derives from the confusion matrix.
    """
    start_time = time.time()

    losses, top1, top5 = AverageMeter(device), AverageMeter(device), AverageMeter(device)
    model.eval()
    # Allocate directly on the target device instead of creating on CPU and moving.
    confusion_matrix = torch.zeros(num_classes, num_classes, device=device)

    with torch.no_grad():
        for images, labels in valid_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            prec1, prec5 = accuracy(outputs, labels, topk=(1, 5))
            confusion_matrix = add_to_confusion_matrix(confusion_matrix, outputs, labels)
            # Meters are weighted by batch size so the averages are per-sample.
            batch_count = images.size(0)
            losses.update(loss.item(), batch_count)
            top1.update(prec1.item(), batch_count)
            top5.update(prec5.item(), batch_count)

    end_time = time.time()
    per_class_results = get_per_class_results(confusion_matrix)
    print(f"==================== Valid Summary: Epoch {epoch+1} ====================", flush=True)
    print(f"Valid Elapsed time: {time.strftime('%H:%M:%S', time.gmtime(end_time - start_time))}", flush=True)
    print(f"Loss: {losses.avg:.2f}\t Prec@1: {top1.avg:.2f}\t Prec@5: {top5.avg:.2f}", flush=True)
    if turn_on_wandb:
        wandb.log({"valid/loss": losses.avg, "valid/top1": top1.avg, "valid/top5": top5.avg}, step=epoch+1)
    return top1.avg, per_class_results

In [None]:
# Start a tracking run only when logging is enabled in the configuration cell.
if turn_on_wandb:
    wandb.init(project="posco2023", name=run_name)

# Main Loop
# Track the best validation top-1 accuracy and checkpoint whenever it improves.
best_top1, best_top1_epoch, best_per_class_results = 0, 0, None
for epoch in range(total_epochs):
    # Forward the notebook-level flag: with logging disabled no run is
    # initialised above, so train/validate must not call wandb.log either.
    train(device, train_loader, model, criterion, optimizer, scheduler, epoch, turn_on_wandb=turn_on_wandb)
    top1, per_class_results = validate(device, valid_loader, model, criterion, epoch, turn_on_wandb=turn_on_wandb)
    if top1 > best_top1:
        best_top1 = top1
        best_top1_epoch = epoch+1  # epochs are reported one-based
        best_per_class_results = per_class_results
        save_ckpt(epoch=epoch+1, model=model, per_class_results=per_class_results, run_name=run_name)

    print(f"Best Prec@1: {best_top1:.2f} at epoch {best_top1_epoch}", flush=True)

# Print Best Results
print(f"Best per class results: {best_per_class_results}", flush=True)

if turn_on_wandb:
    wandb.finish()

In [None]:
# Loading and testing
# Build a fresh 10-class network, hand it to the project's load_ckpt helper
# together with run_name (presumably restoring what save_ckpt wrote — confirm
# in utils), then run one validation pass with wandb logging disabled.
model_for_eval = resnet18(num_classes=10)
model_for_eval = model_for_eval.to(device)
load_ckpt(model=model_for_eval, run_name=run_name)
validate(
    device=device,
    valid_loader=valid_loader,
    model=model_for_eval,
    criterion=criterion,
    epoch=0,
    turn_on_wandb=False,
)
