<a href="https://colab.research.google.com/github/KangJuSeong/classification_model_cifar100/blob/main/torch_livecoding_cifar100.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import pandas as pd
from tqdm import tqdm

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms


# 코드 다시 돌리기 위한 seed 고정
import random
import numpy as np
# Seed every RNG the notebook uses so that re-running it reproduces results.
random.seed(0)
np.random.seed(0)
# torch.manual_seed seeds the CPU and all CUDA generators; manual_seed_all is
# kept explicit for multi-GPU setups.
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
# Seeding alone does not make GPU runs repeatable: cuDNN may select
# nondeterministic convolution algorithms. Force deterministic ones and
# disable autotuning (this can make training somewhat slower).
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
# layer 함수를 만들고 in_chanel과 out_chanel이 들어왔을때 3x3 conv 층을 2층으로 만들고 padding을 1 줘서 사이즈 유지
# batchnorm을 이용하여 공분산 시프트 현상제거
# activation ReLU 사용
# 마지막에 pooling으로 사이즈 감소
def layer(in_chanel, out_chanel):
    return nn.Sequential(
        nn.Conv2d(in_chanel, out_chanel//2, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_chanel//2),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_chanel//2, out_chanel, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_chanel),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(kernel_size=2)
    )
def block(in_chanel, out_chanel):
    """Build the residual branch: two 3x3 conv-BN-ReLU units at constant spatial size.

    ``padding=1`` preserves height and width, so the output can be added
    element-wise to the branch's input when ``in_chanel == out_chanel``
    (``MyModel.forward`` does ``block(x) + x``).

    Args:
        in_chanel: number of input channels.
        out_chanel: number of output channels of both convs.

    Returns:
        An ``nn.Sequential`` with 6 modules (indices 0-5).
    """
    def conv_bn_relu(cin, cout):
        # One 3x3 conv followed by batch norm and an in-place ReLU.
        return [
            nn.Conv2d(cin, cout, kernel_size=3, padding=1),
            nn.BatchNorm2d(cout),
            nn.ReLU(inplace=True),
        ]

    first = conv_bn_relu(in_chanel, out_chanel)
    second = conv_bn_relu(out_chanel, out_chanel)
    return nn.Sequential(*first, *second)

class MyModel(nn.Module):
    """Residual CNN classifier built from ``layer`` stages and ``block`` branches.

    There are three stages. Each one downsamples with ``layer`` and then adds
    the output of a ``block`` back onto that stage's output (identity shortcut).
    Channel widths are 64 -> 256 -> 1024, and every stage halves the spatial
    size, so a 32x32 input reaches the classifier head as 1024x4x4.

    Args:
        in_channels: number of input image channels (3 for RGB).
        num_classes: number of output logits.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()

        self.layer1 = layer(in_channels, 64)  # 64x16x16 for a 32x32 input
        self.block1 = block(64, 64)

        self.layer2 = layer(64, 256)  # 256x8x8
        self.block2 = block(256, 256)

        self.layer3 = layer(256, 1024)  # 1024x4x4
        self.block3 = block(1024, 1024)

        self.classifier = nn.Sequential(
            # Global average pooling down to 1024x1x1. For 32x32 inputs this
            # equals the former AvgPool2d(kernel_size=4). The adaptive form also
            # handles other input sizes, and since it has no parameters the
            # state_dict keys (classifier.2.*, classifier.4.*) are unchanged.
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(1024, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Return class logits (batch, num_classes) for an image batch x of shape (batch, C, H, W)."""
        # Each stage: downsample with layerN, then add blockN's output to its
        # own input (residual connection). Repeated for all three stages.
        x = self.layer1(x)
        x = self.block1(x) + x

        x = self.layer2(x)
        x = self.block2(x) + x

        x = self.layer3(x)
        x = self.block3(x) + x

        return self.classifier(x)

In [None]:
# Build the network for 3-channel RGB input and the 100 CIFAR-100 classes,
# move it to the GPU, and display its structure as the cell output.
model = MyModel(in_channels=3, num_classes=100)
model = model.cuda()
model

MyModel(
  (layer1): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (block1): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
  )
  (layer2): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), 

In [None]:
# Per-channel mean and standard deviation of the raw training pixels, scaled to
# [0, 1] to match ToTensor's output range; used by Normalize below.
data = torchvision.datasets.CIFAR100(root="./", train=True, download=True)
# data.data is the uint8 pixel array of shape (N, 32, 32, 3). Reducing over
# every axis except the channel axis gives the same statistics as decoding and
# stacking all 50,000 PIL images, without the per-image round trip.
x = data.data
mean = np.mean(x, axis=(0, 1, 2))/255
std = np.std(x, axis=(0, 1, 2))/255

# Training pipeline:
# - normalize with the statistics computed above;
# - take a random 32x32 crop from the image zero-padded by 4 pixels per side;
# - apply a random horizontal flip.
# The crop runs after Normalize, so the zero padding corresponds to the
# per-channel mean color rather than black.
train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip()
])
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

train = torchvision.datasets.CIFAR100(root="./", train=True, download=True, transform=train_transform)
test = torchvision.datasets.CIFAR100(root="./", train=False, download=True, transform=test_transform)

# Batch sizes 128 and 256 were tried; 256 is used.
train_loader = torch.utils.data.DataLoader(train, batch_size=256,
                                           shuffle=True, num_workers=2)
# Test loader is not shuffled, so prediction order matches the dataset order.
test_loader = torch.utils.data.DataLoader(test, batch_size=256,
                                          shuffle=False, num_workers=2)

# lr=0.1 and lr=0.01 were tried; 0.01 is used, with momentum 0.9 and weight decay 1e-4.
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()
# MultiStepLR multiplies the learning rate by gamma when the epoch count
# reaches a milestone: 0.01 -> 0.005 at epoch 80 -> 0.0025 at epoch 150.
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[80, 150], gamma=0.5)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [None]:
import os

# Train for 200 epochs. After each epoch, evaluate on the test split and save
# a checkpoint whenever test accuracy reaches a new best.
save_dir = '/content/models'
# torch.save does not create missing directories.
os.makedirs(save_dir, exist_ok=True)
best_correct = -1

for epoch in range(200):
    model.train()
    for img, label in tqdm(train_loader):
        img = img.cuda()
        label = label.cuda()

        optimizer.zero_grad()
        output = model(img)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()
    # MultiStepLR is stepped once per epoch, after that epoch's optimizer steps.
    scheduler.step()

    model.eval()
    correct, all_data = 0, 0
    with torch.no_grad():
        for img, label in test_loader:
            img = img.cuda()
            label = label.cuda()
            output = model(img)

            correct += torch.sum(torch.argmax(output, dim=1) == label).item()
            all_data += len(label)

    # The checkpoint is named by the count of correct test predictions, e.g.
    # 7349.pth for 73.49% on the 10,000-image test split; this is the name
    # the inference cell loads. The old f'{correct/all_data*100}.pth' produced
    # float names such as '73.49.pth'. Only new bests are written, instead of
    # a ~120 MB file every epoch.
    if correct > best_correct:
        best_correct = correct
        torch.save(model.state_dict(), os.path.join(save_dir, f'{correct}.pth'))
    print(f"epoch : {epoch+1} acc : {correct/all_data}")

100%|██████████| 196/196 [00:44<00:00,  4.38it/s]


epoch : 1 acc : 0.1539


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 2 acc : 0.2673


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 3 acc : 0.3529


100%|██████████| 196/196 [00:42<00:00,  4.66it/s]


epoch : 4 acc : 0.3354


100%|██████████| 196/196 [00:42<00:00,  4.66it/s]


epoch : 5 acc : 0.4401


100%|██████████| 196/196 [00:42<00:00,  4.64it/s]


epoch : 6 acc : 0.414


100%|██████████| 196/196 [00:41<00:00,  4.67it/s]


epoch : 7 acc : 0.4356


100%|██████████| 196/196 [00:42<00:00,  4.65it/s]


epoch : 8 acc : 0.5003


100%|██████████| 196/196 [00:42<00:00,  4.64it/s]


epoch : 9 acc : 0.5287


100%|██████████| 196/196 [00:41<00:00,  4.67it/s]


epoch : 10 acc : 0.551


100%|██████████| 196/196 [00:41<00:00,  4.67it/s]


epoch : 11 acc : 0.4974


100%|██████████| 196/196 [00:41<00:00,  4.67it/s]


epoch : 12 acc : 0.538


100%|██████████| 196/196 [00:41<00:00,  4.67it/s]


epoch : 13 acc : 0.5306


100%|██████████| 196/196 [00:41<00:00,  4.67it/s]


epoch : 14 acc : 0.5545


100%|██████████| 196/196 [00:42<00:00,  4.65it/s]


epoch : 15 acc : 0.5779


100%|██████████| 196/196 [00:42<00:00,  4.66it/s]


epoch : 16 acc : 0.6057


100%|██████████| 196/196 [00:42<00:00,  4.65it/s]


epoch : 17 acc : 0.5847


100%|██████████| 196/196 [00:41<00:00,  4.67it/s]


epoch : 18 acc : 0.5944


100%|██████████| 196/196 [00:42<00:00,  4.66it/s]


epoch : 19 acc : 0.6017


100%|██████████| 196/196 [00:42<00:00,  4.66it/s]


epoch : 20 acc : 0.6065


100%|██████████| 196/196 [00:42<00:00,  4.62it/s]


epoch : 21 acc : 0.6145


100%|██████████| 196/196 [00:42<00:00,  4.65it/s]


epoch : 22 acc : 0.6266


100%|██████████| 196/196 [00:42<00:00,  4.65it/s]


epoch : 23 acc : 0.6002


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 24 acc : 0.5991


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 25 acc : 0.6052


100%|██████████| 196/196 [00:41<00:00,  4.67it/s]


epoch : 26 acc : 0.611


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 27 acc : 0.6331


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 28 acc : 0.5944


100%|██████████| 196/196 [00:42<00:00,  4.67it/s]


epoch : 29 acc : 0.6252


100%|██████████| 196/196 [00:41<00:00,  4.67it/s]


epoch : 30 acc : 0.6053


100%|██████████| 196/196 [00:41<00:00,  4.67it/s]


epoch : 31 acc : 0.6111


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 32 acc : 0.6387


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 33 acc : 0.6293


100%|██████████| 196/196 [00:41<00:00,  4.67it/s]


epoch : 34 acc : 0.65


100%|██████████| 196/196 [00:41<00:00,  4.67it/s]


epoch : 35 acc : 0.6353


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 36 acc : 0.6418


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 37 acc : 0.6526


100%|██████████| 196/196 [00:41<00:00,  4.67it/s]


epoch : 38 acc : 0.6566


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 39 acc : 0.6446


100%|██████████| 196/196 [00:41<00:00,  4.67it/s]


epoch : 40 acc : 0.6464


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 41 acc : 0.6605


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 42 acc : 0.6669


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 43 acc : 0.6531


100%|██████████| 196/196 [00:41<00:00,  4.67it/s]


epoch : 44 acc : 0.6781


100%|██████████| 196/196 [00:41<00:00,  4.67it/s]


epoch : 45 acc : 0.6636


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 46 acc : 0.6743


100%|██████████| 196/196 [00:41<00:00,  4.67it/s]


epoch : 47 acc : 0.6548


100%|██████████| 196/196 [00:41<00:00,  4.67it/s]


epoch : 48 acc : 0.6593


100%|██████████| 196/196 [00:41<00:00,  4.67it/s]


epoch : 49 acc : 0.6604


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 50 acc : 0.6607


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 51 acc : 0.68


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 52 acc : 0.6749


100%|██████████| 196/196 [00:41<00:00,  4.67it/s]


epoch : 53 acc : 0.6814


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 54 acc : 0.6821


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 55 acc : 0.6672


100%|██████████| 196/196 [00:41<00:00,  4.67it/s]


epoch : 56 acc : 0.6773


100%|██████████| 196/196 [00:41<00:00,  4.67it/s]


epoch : 57 acc : 0.6939


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 58 acc : 0.6785


100%|██████████| 196/196 [00:41<00:00,  4.67it/s]


epoch : 59 acc : 0.6887


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 60 acc : 0.6954


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 61 acc : 0.6802


100%|██████████| 196/196 [00:41<00:00,  4.67it/s]


epoch : 62 acc : 0.6785


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 63 acc : 0.6882


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 64 acc : 0.6805


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 65 acc : 0.6909


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 66 acc : 0.689


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 67 acc : 0.6868


100%|██████████| 196/196 [00:41<00:00,  4.67it/s]


epoch : 68 acc : 0.685


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 69 acc : 0.6968


100%|██████████| 196/196 [00:41<00:00,  4.67it/s]


epoch : 70 acc : 0.6856


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 71 acc : 0.6874


100%|██████████| 196/196 [00:41<00:00,  4.67it/s]


epoch : 72 acc : 0.6939


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 73 acc : 0.7021


100%|██████████| 196/196 [00:42<00:00,  4.66it/s]


epoch : 74 acc : 0.6971


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 75 acc : 0.7023


100%|██████████| 196/196 [00:41<00:00,  4.67it/s]


epoch : 76 acc : 0.6986


100%|██████████| 196/196 [00:41<00:00,  4.67it/s]


epoch : 77 acc : 0.6959


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 78 acc : 0.6993


100%|██████████| 196/196 [00:42<00:00,  4.64it/s]


epoch : 79 acc : 0.7012


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 80 acc : 0.7023


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 81 acc : 0.7224


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 82 acc : 0.7252


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 83 acc : 0.7251


100%|██████████| 196/196 [00:41<00:00,  4.67it/s]


epoch : 84 acc : 0.7257


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 85 acc : 0.7254


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 86 acc : 0.728


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 87 acc : 0.7288


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 88 acc : 0.7299


100%|██████████| 196/196 [00:41<00:00,  4.67it/s]


epoch : 89 acc : 0.7295


100%|██████████| 196/196 [00:41<00:00,  4.67it/s]


epoch : 90 acc : 0.7307


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 91 acc : 0.7273


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 92 acc : 0.7272


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 93 acc : 0.7293


100%|██████████| 196/196 [00:41<00:00,  4.67it/s]


epoch : 94 acc : 0.7316


100%|██████████| 196/196 [00:41<00:00,  4.67it/s]


epoch : 95 acc : 0.7329


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 96 acc : 0.7301


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 97 acc : 0.7341


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 98 acc : 0.7349


100%|██████████| 196/196 [00:41<00:00,  4.67it/s]


epoch : 99 acc : 0.733


100%|██████████| 196/196 [00:41<00:00,  4.67it/s]


epoch : 100 acc : 0.7311


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 101 acc : 0.734


100%|██████████| 196/196 [00:41<00:00,  4.67it/s]


epoch : 102 acc : 0.7334


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 103 acc : 0.7293


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 104 acc : 0.7319


100%|██████████| 196/196 [00:41<00:00,  4.67it/s]


epoch : 105 acc : 0.7329


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 106 acc : 0.7316


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 107 acc : 0.7324


100%|██████████| 196/196 [00:41<00:00,  4.67it/s]


epoch : 108 acc : 0.7337


100%|██████████| 196/196 [00:41<00:00,  4.67it/s]


epoch : 109 acc : 0.7293


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


epoch : 110 acc : 0.7303


100%|██████████| 196/196 [00:41<00:00,  4.67it/s]


epoch : 111 acc : 0.7325


100%|██████████| 196/196 [00:41<00:00,  4.67it/s]


epoch : 112 acc : 0.7326


100%|██████████| 196/196 [00:42<00:00,  4.66it/s]


epoch : 113 acc : 0.7333


 23%|██▎       | 45/196 [00:09<00:33,  4.53it/s]


KeyboardInterrupt: ignored

In [None]:
# Load the chosen checkpoint and predict a class for every test image. The
# test loader is unshuffled, so row i corresponds to test image i. Write the
# predictions to a CSV with columns id, class.
model.load_state_dict(torch.load('/content/models/7349.pth'))
model.eval()
preds = []

# Labels are not needed for prediction, so they are neither moved to the GPU
# nor kept.
with torch.no_grad():
    for img, _ in tqdm(test_loader):
        logits = model(img.cuda())
        preds.extend(torch.argmax(logits, dim=1).cpu().tolist())

df = pd.DataFrame({'id': list(range(len(preds))),
                   'class': preds})
df.to_csv("20181569.csv", index=False)

100%|██████████| 40/40 [00:03<00:00, 11.28it/s]
