In [1]:
import argparse
import os
import random
import sys
import time
import warnings

import models
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from imbalance_cifar import IMBALANCECIFAR10, IMBALANCECIFAR100
from losses import FocalLoss, LDAMLoss
from sklearn.metrics import confusion_matrix
from tensorboardX import SummaryWriter
from utils import *

In [2]:
# ---------------------------------------------------------------------------
# Experiment configuration: every tunable setting for this run lives here.
# ---------------------------------------------------------------------------

# Dataset / model selection.
dataset = 'cifar10'    # 'cifar10' or 'cifar100'; also decides the number of classes
arch = 'resnet32'      # name looked up in the project-local `models` module
loss_type = 'CE'       # 'LDAM' additionally switches on `use_norm` for the model
train_rule = 'None'    # only used as part of the run name in the cells shown here

# Class-imbalance settings forwarded to the IMBALANCECIFAR* dataset classes.
imb_type = 'exp'
imb_factor = 0.01
rand_number = 0

# Free-form experiment tag appended to the run name.
exp_str = 'exp01'

# Data loading.
workers = 8            # DataLoader worker processes
batch_size = 128       # training batch size

# Optimisation (torch.optim.SGD).
epoch = 1              # NOTE(review): presumably the number of epochs to train - confirm
lr = 0.1
# Momentum must stay strictly below 1: with momentum == 1 the SGD velocity
# buffer never decays, so every past gradient keeps its full weight and the
# updates grow without bound. 0.9 is the conventional setting.
momentum = 0.9
weight_decay = 2e-4

# Logging / checkpointing.
print_freq = 10        # NOTE(review): presumably batches between progress prints - confirm
resume = ''            # NOTE(review): presumably a checkpoint path to resume from - confirm
root_log = 'log'       # base directory for logs
root_model = 'checkpoint'  # base directory for model checkpoints
rot_ratio = 0.1        # used (as a percentage) to name the log sub-folder

# NOTE(review): presumably tracks the best top-1 accuracy seen so far - confirm.
best_acc1 = 0

In [3]:
# Run identifier: the experiment settings joined with underscores. It names
# both the log folder and the checkpoint folder for this run.
store_name = "_".join(
    [dataset, arch, loss_type, train_rule, imb_type, str(imb_factor), exp_str]
)

# Logs are grouped in a sub-folder named after rot_ratio as a whole-number
# percentage. round() is used instead of int() because int() truncates
# floating-point error (0.29 * 100 == 28.999999999999996 -> 28, not 29).
# NOTE: this rebinds root_log, so re-running this cell without first
# re-running the configuration cell appends the sub-folder a second time.
root_log = root_log + "/" + str(round(rot_ratio * 100))

# Bundle of paths handed to prepare_folders().
folders = {
    'root_log': root_log,
    'root_model': root_model,
    'store_name': store_name,
}

In [4]:
# Create the log and checkpoint directories described by `folders`; the
# recorded output shows each folder being reported as it is created.
# NOTE(review): prepare_folders presumably comes from `from utils import *`,
# and the meaning of the second positional argument (False) is defined
# there - confirm.
prepare_folders(folders, False)

creating folder checkpoint
creating folder log/10/cifar10_resnet32_CE_None_exp_0.01_exp01
creating folder checkpoint/cifar10_resnet32_CE_None_exp_0.01_exp01


Create the model

In [5]:
# CIFAR-100 has 100 target classes; any other dataset setting uses 10.
num_classes = 10 if dataset != "cifar100" else 100

# `use_norm` is switched on only for the LDAM loss.
# NOTE(review): presumably selects a normalised classifier layer - confirm in `models`.
use_norm = loss_type == "LDAM"

# Look the architecture up by name in the project-local `models` module,
# build it, then wrap it for multi-GPU data parallelism and move it to CUDA.
build_model = vars(models)[arch]
model = build_model(num_classes=num_classes, use_norm=use_norm)
model = torch.nn.DataParallel(model).cuda()

num_trans : 16


In [6]:
# SGD over all model parameters, driven by the settings from the configuration cell.
optimizer = torch.optim.SGD(
    model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay
)

In [7]:
# Turn on cuDNN benchmark mode (torch.backends.cudnn.benchmark), which lets
# cuDNN try several convolution algorithms and keep the fastest one.
cudnn.benchmark = True

Data loading

In [8]:
# Per-channel mean / std normalisation shared by the training and validation pipelines.
cifar_normalize = transforms.Normalize(
    (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
)

# Training: random 32x32 crop (after 4-pixel padding) and random horizontal
# flip as augmentation, then tensor conversion and normalisation.
train_steps = [
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    cifar_normalize,
]
transform_train = transforms.Compose(train_steps)

# Validation: no augmentation, only tensor conversion and normalisation.
transform_val = transforms.Compose([transforms.ToTensor(), cifar_normalize])

In [9]:

# Each supported dataset name maps to a pair of classes:
#   (imbalanced training-set class, standard torchvision class for the test split)
dataset_classes = {
    "cifar10": (IMBALANCECIFAR10, datasets.CIFAR10),
    "cifar100": (IMBALANCECIFAR100, datasets.CIFAR100),
}

# Fail immediately on an unknown dataset name. Merely warning would leave
# train_dataset / val_dataset undefined and surface later as a NameError.
if dataset not in dataset_classes:
    raise ValueError(
        "Dataset is not listed: {!r} (expected one of {})".format(
            dataset, sorted(dataset_classes)
        )
    )

train_dataset_cls, val_dataset_cls = dataset_classes[dataset]

# Training split, built with the imbalance settings from the configuration cell.
train_dataset = train_dataset_cls(
    root="./data",
    imb_type=imb_type,
    imb_factor=imb_factor,
    rand_number=rand_number,
    train=True,
    download=True,
    transform=transform_train,
)

# Validation uses the unmodified torchvision test split (train=False).
val_dataset = val_dataset_cls(
    root="./data", train=False, download=True, transform=transform_val
)

Files already downloaded and verified
Files already downloaded and verified


In [10]:
# Class counts reported by the training dataset, printed under a heading line.
cls_num_list = train_dataset.get_cls_num_list()
print("cls num list:", cls_num_list, sep="\n")

# No custom sampler is used, so the training loader shuffles instead.
train_sampler = None
shuffle_train = train_sampler is None

# Options common to both loaders.
loader_opts = dict(num_workers=workers, pin_memory=True)

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=shuffle_train,
    sampler=train_sampler,
    **loader_opts,
)

# Validation runs in fixed order with a batch size of 100.
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=100,
    shuffle=False,
    **loader_opts,
)

cls num list:
[5000, 2997, 1796, 1077, 645, 387, 232, 139, 83, 50]


Set up training and evaluation logs

In [None]:
# All log artefacts for this run live under <root_log>/<store_name>.
# The notebook never defines an `args` object (the settings are plain
# variables from the configuration cell), so those variables are used here.
run_log_dir = os.path.join(root_log, store_name)

# CSV logs for training and evaluation. These handles are deliberately left
# open so later cells can write to them.
# NOTE(review): close both once training has finished.
log_training = open(os.path.join(run_log_dir, "log_train.csv"), "w")
log_testing = open(os.path.join(run_log_dir, "log_test.csv"), "w")

# Snapshot of the settings used for this run, stored next to the logs.
run_config = {
    "dataset": dataset,
    "arch": arch,
    "loss_type": loss_type,
    "train_rule": train_rule,
    "imb_type": imb_type,
    "imb_factor": imb_factor,
    "rand_number": rand_number,
    "exp_str": exp_str,
    "workers": workers,
    "epoch": epoch,
    "batch_size": batch_size,
    "lr": lr,
    "momentum": momentum,
    "weight_decay": weight_decay,
    "print_freq": print_freq,
    "resume": resume,
    "root_log": root_log,
    "root_model": root_model,
    "rot_ratio": rot_ratio,
    "store_name": store_name,
}
with open(os.path.join(run_log_dir, "args.txt"), "w") as f:
    f.write(str(run_config))

# TensorBoard event files go to the same directory.
tf_writer = SummaryWriter(log_dir=run_log_dir)