[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://github.com/CryptoSalamander/pytorch_paper_implementation/blob/master/resnet/resnet_cifar10.ipynb)

In [1]:
import torchvision
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from resnet import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152
import os
import torchvision.models as models
import numpy as np

In [2]:
mixup_alpha = 1.0

def mixup_data(x, y):
    lam = np.random.beta(mixup_alpha, mixup_alpha)
    batch_size = x.size()[0]
    index = torch.randperm(batch_size).cuda()
    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    return lam * criterion(pred, y_a) + (1-lam) * criterion(pred, y_b)

class LabelSmoothingCrossEntropy(nn.Module):
    def __init__(self):
        super(LabelSmoothingCrossEntropy, self).__init__()
    
    def forward(self, y, targets, smoothing=0.1):
        confidence = 1. - smoothing
        log_probs = F.log_softmax(y, dim=-1)
        true_probs = torch.zeros_like(log_probs)
        true_probs.fill_(smoothing / (y.shape[1] - 1))
        true_probs.scatter_(1, targets.data.unsqueeze(1), confidence)
        return torch.mean(torch.sum(true_probs * -log_probs, dim=-1))

In [3]:
# Simple Learning Rate Scheduler
def lr_scheduler(optimizer, epoch):
    lr = learning_rate
    if epoch >= 50:
        lr /= 10
    if epoch >= 100:
        lr /= 10
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

# Xavier         
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform(m.weight)
        m.bias.data.fill_(0.01)

In [4]:
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
])

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=8)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=100, shuffle=False, num_workers=8)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
device = 'cuda'
model = ResNet18()
# ResNet18, ResNet34, ResNet50, ResNet101, ResNet152 중에 택일하여 사용

In [6]:
model.apply(init_weights)
model = model.to(device)

  torch.nn.init.xavier_uniform(m.weight)


In [7]:
learning_rate = 0.1
num_epoch = 150
model_name = 'model.pth'

loss_fn = LabelSmoothingCrossEntropy()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=0.0002)

train_loss = 0
valid_loss = 0
correct = 0
total_cnt = 0
best_acc = 0

In [8]:
# Train
for epoch in range(num_epoch):
    print(f"====== { epoch+1} epoch of { num_epoch } ======")
    model.train()
    lr_scheduler(optimizer, epoch)
    train_loss = 0
    valid_loss = 0
    correct = 0
    total_cnt = 0
    # Train Phase
    for step, (inputs, targets) in enumerate(train_loader):
        #  input and target
        inputs, targets = inputs.to(device), targets.to(device)
        inputs, targets_a, targets_b, lam = mixup_data(inputs, targets)
        optimizer.zero_grad()
        
        logits = model(inputs)
        loss = mixup_criterion(loss_fn, logits, targets_a, targets_b, lam)
        loss.backward()
        
        optimizer.step()
        train_loss += loss.item()
        _, predict = logits.max(1)
        
        total_cnt += targets.size(0)
        correct += (lam * predict.eq(targets_a).sum().item() + (1-lam) * predict.eq(targets_b).sum().item())
        
        if step % 100 == 0 and step != 0:
            print(f"\n====== { step } Step of { len(train_loader) } ======")
            print(f"Train Acc : { correct / total_cnt }")
            print(f"Train Loss : { loss.item() / targets.size(0) }")
            
    correct = 0
    total_cnt = 0
    
# Test Phase
    with torch.no_grad():
        model.eval()
        for step, batch in enumerate(test_loader):
            # input and target
            batch[0], batch[1] = batch[0].to(device), batch[1].to(device)
            total_cnt += batch[1].size(0)
            logits = model(batch[0])
            valid_loss += loss_fn(logits, batch[1])
            _, predict = logits.max(1)
            correct += predict.eq(batch[1]).sum().item()
        valid_acc = correct / total_cnt
        print(f"\nValid Acc : { valid_acc }")    
        print(f"Valid Loss : { valid_loss / total_cnt }")

        if(valid_acc > best_acc):
            best_acc = valid_acc
            torch.save(model, model_name)
            print("Model Saved!")


Train Acc : 0.13737284557829862
Train Loss : 0.017385829240083694

Train Acc : 0.16739327568657797
Train Loss : 0.01775568351149559

Train Acc : 0.1901568744619238
Train Loss : 0.017090903595089912

Valid Acc : 0.3388
Valid Loss : 0.019300032407045364
Model Saved!

Train Acc : 0.26608956016122653
Train Loss : 0.014304492622613907

Train Acc : 0.27526716890928626
Train Loss : 0.01670186221599579

Train Acc : 0.28311543668874434
Train Loss : 0.015808356925845146

Valid Acc : 0.4112
Valid Loss : 0.017903869971632957
Model Saved!

Train Acc : 0.32866543672371384
Train Loss : 0.015131600201129913

Train Acc : 0.3279159760185664
Train Loss : 0.016011128202080727

Train Acc : 0.3350894299505749
Train Loss : 0.014554569497704506

Valid Acc : 0.4
Valid Loss : 0.01811722293496132

Train Acc : 0.36307749393514244
Train Loss : 0.013027166947722435

Train Acc : 0.36359003752580776
Train Loss : 0.013207728043198586

Train Acc : 0.36355960550686695
Train Loss : 0.014854035340249538

Valid Acc : 0.33