In [None]:
import os
import time

import torch
import wandb
from torchvision import transforms

from resnet import resnet18
from utils import BalancedSoftmax, CustomDataset, AverageMeter, accuracy, add_to_confusion_matrix, get_per_class_results, make_deterministic, save_ckpt, load_ckpt


# ---- Run configuration -------------------------------------------------------
# When True the model cell builds the network with pretrained=True and swaps in a
# new 10-way head; when False it builds a 10-class network directly.
use_pretrained = True
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Dataset locations are resolved relative to the kernel's current working directory.
ROOT_DIR = os.getcwd()
train_dataset_path = os.path.join(ROOT_DIR, "posco_data/places10_LT/train")
valid_dataset_path = os.path.join(ROOT_DIR,"posco_data/places10_LT/valid")

# Batch size shared by the train and validation DataLoaders.
batch_size = 64
# Number of train + validate passes executed by the main loop.
total_epochs = 30
# MultiStepLR milestones: the scheduler is stepped once per epoch and multiplies
# the learning rate by gamma=0.1 at each of these epoch counts.
lr_steps = [10, 20, 25]
# Weights & Biases logging flag; train()/validate() take a parameter of the same name.
turn_on_wandb = False
# Identifier passed to wandb.init, save_ckpt and load_ckpt.
run_name = "resnet18_places10_LT_balanced_softmax"

# Project helper from utils; presumably seeds the RNGs for reproducibility -- confirm in utils.
make_deterministic(random_seed=42)

In [None]:
from torch.nn.modules.loss import _Loss
import torch.nn.functional as F

class BalancedSoftmax(_Loss):
    """Balanced Softmax loss for class-imbalanced classification.

    Every logit is shifted by the log of its class's training-sample count
    before ordinary cross-entropy is applied, so classes with many samples
    need a larger raw logit to win during training.
    """

    def __init__(self, samples_per_class):
        """Store the per-class sample counts.

        Args:
            samples_per_class: 1-D sequence or tensor with one strictly
                positive sample count per class, indexed by class id.

        Raises:
            ValueError: if the counts are not 1-D, are empty, or contain a
                value that is not strictly positive (log of such a value
                would inject -inf/NaN into every loss computation).
        """
        super(BalancedSoftmax, self).__init__()
        # as_tensor + clone accepts lists and tensors alike, always owns its
        # data, and avoids the copy-construct warning torch.tensor() emits
        # when it is handed an existing tensor.
        counts = torch.as_tensor(samples_per_class).detach().clone()
        if counts.dim() != 1 or counts.numel() == 0:
            raise ValueError(
                "samples_per_class must be a non-empty 1-D sequence, "
                f"got shape {tuple(counts.shape)}"
            )
        if not bool((counts > 0).all()):
            raise ValueError("samples_per_class must contain only strictly positive counts")
        self.sample_per_class = counts

    def balanced_softmax_loss(self, labels, logits, sample_per_class, reduction="mean"):
        """Cross-entropy on logits shifted by the log class prior.

        Args:
            labels: class-index targets, shape (batch_size,).
            logits: raw model outputs, shape (batch_size, num_classes).
            sample_per_class: 1-D tensor of per-class counts, length num_classes.
            reduction: forwarded to F.cross_entropy ("mean", "sum" or "none").

        Returns:
            A scalar tensor for "mean"/"sum", or a (batch_size,) tensor for "none".

        Raises:
            ValueError: if sample_per_class does not have one entry per logit column.
        """
        # Match dtype/device of the logits, then take the log once on the
        # (num_classes,) vector instead of on a (batch_size, num_classes) copy.
        log_prior = sample_per_class.type_as(logits).log()
        if log_prior.dim() != 1 or log_prior.shape[0] != logits.shape[-1]:
            raise ValueError(
                f"sample_per_class has shape {tuple(log_prior.shape)} but logits have "
                f"{logits.shape[-1]} classes"
            )
        adjusted_logits = logits + log_prior  # broadcasts over the batch dimension
        return F.cross_entropy(input=adjusted_logits, target=labels, reduction=reduction)

    def forward(self, input, label, reduction='mean'):
        """Compute the balanced softmax loss of `input` logits against `label`."""
        return self.balanced_softmax_loss(label, input, self.sample_per_class, reduction)


In [None]:
# Build Model
# Both branches end with a 10-way classifier; the network is assembled first and
# moved to `device` in one step afterwards.
if not use_pretrained:
    model = resnet18(num_classes=10)
else:
    # Load the 1000-class pretrained network, then replace its final layer
    # with a freshly initialised 10-way linear head.
    model = resnet18(pretrained=True, num_classes=1000)
    model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=10)
model = model.to(device)


# Make Dataset & DataLoader
# Training pipeline: random crop / flip / colour jitter, then tensor conversion and
# per-channel normalisation (the mean/std look like the usual ImageNet statistics).
train_transforms = transforms.Compose(
    [
        transforms.RandomResizedCrop(size=224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(0.4, 0.4, 0.4, 0),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)
# Validation pipeline: fixed resize + centre crop, same normalisation, no randomness.
valid_transforms = transforms.Compose(
    [
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

train_dataset = CustomDataset(
    dataset_path=train_dataset_path,
    transform=train_transforms,
    use_randaug=False,
)
valid_dataset = CustomDataset(
    dataset_path=valid_dataset_path,
    transform=valid_transforms,
    use_randaug=False,
)

# Only the training loader shuffles; both use the configured batch size.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
)
valid_loader = torch.utils.data.DataLoader(
    dataset=valid_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
)

# Make Optimizer & Loss
# SGD with momentum; the learning rate is cut by 10x at every epoch listed in lr_steps.
optimizer = torch.optim.SGD(
    params=model.parameters(),
    lr=1e-2,
    momentum=0.9,
    weight_decay=2e-4,
)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer=optimizer,
    milestones=lr_steps,
    gamma=0.1,
)
# One sample count per class, in class-index order.
criterion = BalancedSoftmax(
    samples_per_class=[1000, 555, 308, 170, 94, 52, 29, 16, 9, 5],
)

In [None]:
def train(device, train_loader, model, criterion, optimizer, scheduler, epoch, turn_on_wandb=False):
    """Run one training epoch, then print (and optionally log) its averaged metrics.

    Args:
        device: torch device every batch is moved to.
        train_loader: iterable yielding (images, labels) batches.
        model: network to optimise; put into train mode here.
        criterion: callable `criterion(outputs, labels)` returning a single-element loss tensor.
        optimizer: optimizer stepped once per batch.
        scheduler: learning-rate scheduler stepped once, after the last batch.
        epoch: zero-based epoch index; reported as epoch+1.
        turn_on_wandb: when True, the epoch averages are sent to wandb.log.
    """
    epoch_start = time.time()

    # AverageMeter is a project helper; presumably a sample-weighted running average.
    loss_meter = AverageMeter(device)
    top1_meter = AverageMeter(device)
    top5_meter = AverageMeter(device)

    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        n_samples = images.size(0)
        prec1, prec5 = accuracy(outputs, labels, topk=(1, 5))
        loss_meter.update(loss.item(), n_samples)
        top1_meter.update(prec1.item(), n_samples)
        top5_meter.update(prec5.item(), n_samples)

    # Timing stops before the scheduler step so it only covers the batch loop.
    elapsed = time.time() - epoch_start
    scheduler.step()

    print(f"==================== Train Summary: Epoch {epoch+1} ====================", flush=True)
    print(f"Train Epoch Elapsed time: {time.strftime('%H:%M:%S', time.gmtime(elapsed))}", flush=True)
    print(f"Loss: {loss_meter.avg:.2f}\t Prec@1: {top1_meter.avg:.2f}\t Prec@5: {top5_meter.avg:.2f}", flush=True)
    if not turn_on_wandb:
        return
    wandb.log(
        {"train/loss": loss_meter.avg, "train/top1": top1_meter.avg, "train/top5": top5_meter.avg},
        step=epoch+1,
    )


def validate(device, valid_loader, model, criterion, epoch, turn_on_wandb=False, num_classes=10):
    """Evaluate `model` on `valid_loader` and report loss, top-1 and top-5 precision.

    Args:
        device: torch device the batches and the confusion matrix live on.
        valid_loader: iterable yielding (images, labels) batches.
        model: network to evaluate; switched to eval mode here and left in it.
        criterion: callable `criterion(outputs, labels)` returning a single-element loss tensor.
        epoch: zero-based epoch index; reported as epoch+1 in the summary and as the W&B step.
        turn_on_wandb: when True, the averaged metrics are sent to wandb.log.
        num_classes: side length of the square confusion matrix. Defaults to 10,
            the value that was previously hard-coded.

    Returns:
        Tuple of (average top-1 precision, per-class results as produced by
        get_per_class_results from the accumulated confusion matrix).
    """
    start_time = time.time()

    # AverageMeter is a project helper; presumably a sample-weighted running average.
    losses, top1, top5 = AverageMeter(device), AverageMeter(device), AverageMeter(device)
    model.eval()
    # Allocate directly on the target device instead of building on CPU and copying.
    confusion_matrix = torch.zeros(num_classes, num_classes, device=device)

    with torch.no_grad():
        for images, labels in valid_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            prec1, prec5 = accuracy(outputs, labels, topk=(1, 5))
            confusion_matrix = add_to_confusion_matrix(confusion_matrix, outputs, labels)
            losses.update(loss.item(), images.size(0))
            top1.update(prec1.item(), images.size(0))
            top5.update(prec5.item(), images.size(0))

    end_time = time.time()
    per_class_results = get_per_class_results(confusion_matrix)
    print(f"==================== Valid Summary: Epoch {epoch+1} ====================", flush=True)
    print(f"Valid Elapsed time: {time.strftime('%H:%M:%S', time.gmtime(end_time - start_time))}", flush=True)
    print(f"Loss: {losses.avg:.2f}\t Prec@1: {top1.avg:.2f}\t Prec@5: {top5.avg:.2f}", flush=True)
    if turn_on_wandb:
        wandb.log({"valid/loss": losses.avg, "valid/top1": top1.avg, "valid/top5": top5.avg}, step=epoch+1)
    return top1.avg, per_class_results

In [None]:
# Honour the `turn_on_wandb` switch from the config cell: a W&B run is opened,
# logged to and closed only when the flag is True. Set it to True in the config
# cell to record this run.
if turn_on_wandb:
    wandb.init(project="posco2023", entity="alex4727", name=run_name)

# Main Loop
# Track the best validation top-1 seen so far and checkpoint whenever it improves.
best_top1, best_top1_epoch, best_per_class_results = 0, 0, None
for epoch in range(total_epochs):
    train(device, train_loader, model, criterion, optimizer, scheduler, epoch, turn_on_wandb=turn_on_wandb)
    top1, per_class_results = validate(device, valid_loader, model, criterion, epoch, turn_on_wandb=turn_on_wandb)
    if top1 > best_top1:
        best_top1 = top1
        best_top1_epoch = epoch+1
        best_per_class_results = per_class_results
        save_ckpt(epoch=epoch+1, model=model, per_class_results=per_class_results, run_name=run_name)

    print(f"Best Prec@1: {best_top1:.2f} at epoch {best_top1_epoch}", flush=True)

# Print Best Results
print(f"Best per class results: {best_per_class_results}", flush=True)

if turn_on_wandb:
    wandb.finish()

In [None]:
# Loading and testing
# Rebuild a 10-class network, restore the checkpoint saved under `run_name`, and
# run one validation pass; the returned tuple is shown as the cell output.
model_for_eval = resnet18(num_classes=10)
model_for_eval = model_for_eval.to(device)
load_ckpt(run_name=run_name, model=model_for_eval)
validate(
    device=device,
    valid_loader=valid_loader,
    model=model_for_eval,
    criterion=criterion,
    epoch=0,
    turn_on_wandb=False,
)
