In [1]:
import torchvision
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from resnet import ResNet18
import os

In [2]:
# Simple Learning Rate Scheduler
def lr_scheduler(optimizer, epoch):
    lr = learning_rate
    if epoch >= 50:
        lr /= 10
    if epoch >= 100:
        lr /= 10
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

# Xavier         
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform(m.weight)
        m.bias.data.fill_(0.01)

In [3]:
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
])

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=8)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=100, shuffle=False, num_workers=8)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
device = 'cuda'
model = ResNet18()
model.apply(init_weights)
model = model.to(device)

learning_rate = 0.1
num_epoch = 150
model_name = 'model.pth'

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=0.0001)

train_loss = 0
valid_loss = 0
correct = 0
total_cnt = 0
best_acc = 0

  torch.nn.init.xavier_uniform(m.weight)


In [5]:
# Train
for epoch in range(num_epoch):
    print(f"====== { epoch+1} epoch of { num_epoch } ======")
    model.train()
    lr_scheduler(optimizer, epoch)
    train_loss = 0
    valid_loss = 0
    correct = 0
    total_cnt = 0
    # Train Phase
    for step, batch in enumerate(train_loader):
        #  input and target
        batch[0], batch[1] = batch[0].to(device), batch[1].to(device)
        optimizer.zero_grad()
        
        logits = model(batch[0])
        loss = loss_fn(logits, batch[1])
        loss.backward()
        
        optimizer.step()
        train_loss += loss.item()
        _, predict = logits.max(1)
        
        total_cnt += batch[1].size(0)
        correct +=  predict.eq(batch[1]).sum().item()
        
        if step % 100 == 0 and step != 0:
            print(f"\n====== { step } Step of { len(train_loader) } ======")
            print(f"Train Acc : { correct / total_cnt }")
            print(f"Train Loss : { loss.item() / batch[1].size(0) }")
            
    correct = 0
    total_cnt = 0
    
    # Test Phase
    model.eval()
    for step, batch in enumerate(test_loader):
        # input and target
        batch[0], batch[1] = batch[0].to(device), batch[1].to(device)
        total_cnt += batch[1].size(0)
        logits = model(batch[0])
        valid_loss += loss_fn(logits, batch[1])
        _, predict = logits.max(1)
        correct += predict.eq(batch[1]).sum().item()
    valid_acc = correct / total_cnt
    print(f"\nValid Acc : { valid_acc }")    
    print(f"Valid Loss : { valid_loss / total_cnt }")
    
    if(valid_acc > best_acc):
        best_acc = valid_acc
        torch.save(model, model_name)
        print("Model Saved!")



  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)



Train Acc : 0.15594059405940594
Train Loss : 0.016624178737401962

Train Acc : 0.18765547263681592
Train Loss : 0.018544675782322884

Train Acc : 0.2237074335548173
Train Loss : 0.014839627780020237

Valid Acc : 0.3685
Valid Loss : 0.020532814785838127
Model Saved!

Train Acc : 0.359220297029703
Train Loss : 0.013369300402700901

Train Acc : 0.37779850746268656
Train Loss : 0.01183394156396389

Train Acc : 0.3895608388704319
Train Loss : 0.012920999899506569

Valid Acc : 0.4677
Valid Loss : 0.014703524298965931
Model Saved!

Train Acc : 0.44129022277227725
Train Loss : 0.010599120520055294

Train Acc : 0.44748911691542287
Train Loss : 0.01061832346022129

Train Acc : 0.4534624169435216
Train Loss : 0.010914064012467861

Valid Acc : 0.3775
Valid Loss : 0.01807265542447567

Train Acc : 0.5095142326732673
Train Loss : 0.011263824999332428

Train Acc : 0.5133706467661692
Train Loss : 0.009082668460905552

Train Acc : 0.5167151162790697
Train Loss : 0.011405550874769688

Valid Acc : 0.5158