[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://github.com/CryptoSalamander/pytorch_paper_implementation/blob/master/resnet/resnet_cifar10.ipynb)

In [1]:
import torchvision
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from resnet import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152
import os
import torchvision.models as models

In [2]:
# Simple Learning Rate Scheduler
def lr_scheduler(optimizer, epoch):
    lr = learning_rate
    if epoch >= 50:
        lr /= 10
    if epoch >= 100:
        lr /= 10
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

# Xavier         
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform(m.weight)
        m.bias.data.fill_(0.01)

In [3]:
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
])

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=8)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=100, shuffle=False, num_workers=8)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
device = 'cuda'
model = ResNet18()
# ResNet18, ResNet34, ResNet50, ResNet101, ResNet152 중에 택일하여 사용

In [5]:
model.apply(init_weights)
model = model.to(device)

  torch.nn.init.xavier_uniform(m.weight)


In [6]:
learning_rate = 0.1
num_epoch = 150
model_name = 'model.pth'

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=0.0002)

train_loss = 0
valid_loss = 0
correct = 0
total_cnt = 0
best_acc = 0

In [7]:
# Train
for epoch in range(num_epoch):
    print(f"====== { epoch+1} epoch of { num_epoch } ======")
    model.train()
    lr_scheduler(optimizer, epoch)
    train_loss = 0
    valid_loss = 0
    correct = 0
    total_cnt = 0
    # Train Phase
    for step, batch in enumerate(train_loader):
        #  input and target
        batch[0], batch[1] = batch[0].to(device), batch[1].to(device)
        optimizer.zero_grad()
        
        logits = model(batch[0])
        loss = loss_fn(logits, batch[1])
        loss.backward()
        
        optimizer.step()
        train_loss += loss.item()
        _, predict = logits.max(1)
        
        total_cnt += batch[1].size(0)
        correct +=  predict.eq(batch[1]).sum().item()
        
        if step % 100 == 0 and step != 0:
            print(f"\n====== { step } Step of { len(train_loader) } ======")
            print(f"Train Acc : { correct / total_cnt }")
            print(f"Train Loss : { loss.item() / batch[1].size(0) }")
            
    correct = 0
    total_cnt = 0
    
# Test Phase
    with torch.no_grad():
        model.eval()
        for step, batch in enumerate(test_loader):
            # input and target
            batch[0], batch[1] = batch[0].to(device), batch[1].to(device)
            total_cnt += batch[1].size(0)
            logits = model(batch[0])
            valid_loss += loss_fn(logits, batch[1])
            _, predict = logits.max(1)
            correct += predict.eq(batch[1]).sum().item()
        valid_acc = correct / total_cnt
        print(f"\nValid Acc : { valid_acc }")    
        print(f"Valid Loss : { valid_loss / total_cnt }")

        if(valid_acc > best_acc):
            best_acc = valid_acc
            torch.save(model, model_name)
            print("Model Saved!")


Train Acc : 0.1938428217821782
Train Loss : 0.01568794623017311

Train Acc : 0.23833955223880596
Train Loss : 0.014690200798213482

Train Acc : 0.2711015365448505
Train Loss : 0.0130706075578928

Valid Acc : 0.3751
Valid Loss : 0.017558705061674118
Model Saved!

Train Acc : 0.41104579207920794
Train Loss : 0.013532757759094238

Train Acc : 0.43023165422885573
Train Loss : 0.010527488775551319

Train Acc : 0.44466362126245845
Train Loss : 0.011626498773694038

Valid Acc : 0.3877
Valid Loss : 0.01789134182035923
Model Saved!

Train Acc : 0.5209622524752475
Train Loss : 0.00938087422400713

Train Acc : 0.5340485074626866
Train Loss : 0.009951438754796982

Train Acc : 0.5465894933554817
Train Loss : 0.007221399340778589

Valid Acc : 0.5959
Valid Loss : 0.011279930360615253
Model Saved!

Train Acc : 0.6154857673267327
Train Loss : 0.008409875445067883

Train Acc : 0.6218128109452736
Train Loss : 0.008512405678629875

Train Acc : 0.6285818106312292
Train Loss : 0.006934578064829111

Valid A