In [None]:
import torch
import torchvision
import random
from model import *

# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# Fix the PyTorch and Python RNG seeds so repeated runs draw the same numbers.
torch.manual_seed(777)
random.seed(777)
if device == 'cuda':
    torch.cuda.manual_seed(777)

In [None]:
# Training hyperparameters shared by the cells below.
epoch = 200            # number of passes train() makes over the training loader
batch_size = 64        # samples per mini-batch handed to both DataLoaders
learning_rate = 0.01   # step size passed to the optimizer

In [None]:
root = "./CIFAR100"
# CIFAR100 lives directly under torchvision.datasets and selects its split with
# the boolean `train` flag (it has no `split` argument).
# ToTensor converts each PIL image into a float tensor so that the default
# DataLoader collation can stack samples into a batch and move it to `device`.
data_train = torchvision.datasets.CIFAR100(
    root,
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)
data_test = torchvision.datasets.CIFAR100(
    root,
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)

In [None]:
# Training batches are reshuffled each epoch; the trailing partial batch is
# dropped so every optimisation step sees exactly `batch_size` samples.
train_loader = torch.utils.data.DataLoader(
    dataset=data_train,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)
# Evaluation should cover every test sample, so the final (possibly smaller)
# batch is kept instead of being silently discarded.
test_loader = torch.utils.data.DataLoader(
    dataset=data_test,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
)

In [None]:
# Network: ResNet assembled from Bottleneck blocks with stage depths
# [3, 4, 6, 3] and 100 output classes.
# NOTE(review): ResNet/Bottleneck come from the local `model` module; this
# looks like the ResNet-50 layout — confirm against model.py.
model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=100)
model = model.to(device)

# Classification loss, placed on the same device as the model.
criterion = torch.nn.CrossEntropyLoss()
criterion = criterion.to(device)

# Adam over all model parameters, using the notebook-level learning rate.
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

In [None]:
def train_one_epoch(print_result = False):
    """Run one optimisation pass over ``train_loader``.

    Uses the notebook-level ``model``, ``criterion``, ``optimizer``,
    ``train_loader`` and ``device``.

    Args:
        print_result: When true, print the epoch's mean per-sample loss and
            the fraction of correctly classified samples.

    Returns:
        None. Metrics are only printed.
    """
    model.train()
    loss_sum = 0.0   # running sum of per-sample losses
    correct = 0      # running count of correct predictions
    length = 0       # number of samples seen so far

    for X, Y in train_loader:
        X = X.to(device)
        Y = Y.to(device)
        # Dimension 0 is the batch dimension; dimension 1 would be channels.
        batch = X.size(0)

        optimizer.zero_grad()

        pred = model(X)
        loss = criterion(pred, Y)
        loss.backward()
        # The optimizer, not the loss tensor, applies the parameter update.
        optimizer.step()

        # criterion (CrossEntropyLoss with default settings in this notebook)
        # yields the batch mean, so weight it by the batch size to be able to
        # report a true per-sample mean at the end.
        loss_sum += loss.item() * batch
        # .item() keeps the counter a plain Python number instead of a tensor.
        correct += (torch.argmax(pred, 1) == Y).sum().item()
        length += batch

    if print_result:
        if length == 0:
            # Avoid a ZeroDivisionError when the loader produced no batches.
            print("no training batches were produced")
            return
        print("loss :", loss_sum / length)
        print("accuracy:", correct / length)

In [None]:
def eval():
    """Print the mean loss and accuracy of ``model`` on ``test_loader``.

    Uses the notebook-level ``model``, ``criterion``, ``test_loader`` and
    ``device``. No gradients are computed and no parameters are updated.

    NOTE: this function shadows Python's builtin ``eval`` inside the notebook;
    the name is kept because ``train()`` calls it as ``eval()``.

    Returns:
        None. Metrics are only printed.
    """
    model.eval()
    loss_sum = 0.0   # running sum of per-sample losses
    correct = 0      # running count of correct predictions
    length = 0       # number of samples seen so far

    # Evaluation must not build autograd graphs or touch the optimizer.
    with torch.no_grad():
        for X, Y in test_loader:
            X = X.to(device)
            Y = Y.to(device)
            # Dimension 0 is the batch dimension; dimension 1 would be channels.
            batch = X.size(0)

            pred = model(X)
            loss = criterion(pred, Y)

            # criterion (CrossEntropyLoss with default settings in this
            # notebook) yields the batch mean; weight by batch size so the
            # final figure is a per-sample mean even with a short last batch.
            loss_sum += loss.item() * batch
            correct += (torch.argmax(pred, 1) == Y).sum().item()
            length += batch

    if length == 0:
        # Avoid a ZeroDivisionError when the loader produced no batches.
        print("no evaluation batches were produced")
        return
    print("loss :", loss_sum / length)
    print("accuracy:", correct / length)

In [None]:
def train():
    """Train for ``epoch`` epochs, evaluating after each one.

    Every iteration prints an ``EPOCH[n]`` banner (1-based), runs
    ``train_one_epoch`` with result printing enabled, and then calls the
    notebook's own ``eval`` function.
    """
    for epoch_number in range(1, epoch + 1):
        print(f"EPOCH[{epoch_number}]")

        print("==== train ====")
        train_one_epoch(print_result=True)

        print("==== eval ====")
        eval()