In [13]:
import torch
import torch.nn as nn
class BasicBlock(nn.Module):
  """Residual block: conv-BN-ReLU -> conv-BN, added to a 1x1 projection shortcut, then ReLU.

  Both main-path convolutions use "same" padding (kernel_size // 2, stride 1),
  so the spatial size is preserved. The channel count changes from
  in_channels to out_channels, e.g. (32, 32, 3) -> (32, 32, 64) for BasicBlock(3, 64).

  Args:
    in_channels: number of channels of the input feature map.
    out_channels: number of channels produced by the block.
    kernel_size: kernel size of the two main-path convolutions (int or pair).
      Must be odd, otherwise no symmetric padding keeps the spatial size and
      the residual addition would fail.

  Raises:
    ValueError: if any dimension of kernel_size is even.
  """
  def __init__(self, in_channels, out_channels, kernel_size=3):
    super(BasicBlock, self).__init__()
    sizes = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
    if any(k % 2 == 0 for k in sizes):
      raise ValueError(f"kernel_size must be odd to preserve spatial size, got {kernel_size}")
    # "same" padding for odd kernels; equals the previous hard-coded 1 for the default kernel_size=3.
    padding = tuple(k // 2 for k in sizes)

    self.c1 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
    self.c2 = nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, padding=padding)
    # 1x1 projection so the shortcut has out_channels channels and can be added to the main path.
    self.downsample = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    self.bn1 = nn.BatchNorm2d(num_features=out_channels)
    self.bn2 = nn.BatchNorm2d(num_features=out_channels)

    self.relu = nn.ReLU()

  def forward(self, x):
    """Apply the block to x of shape (N, in_channels, H, W); returns (N, out_channels, H, W)."""
    shortcut = x  # keep the block input for the skip connection
    x = self.relu(self.bn1(self.c1(x)))
    x = self.bn2(self.c2(x))
    shortcut = self.downsample(shortcut)  # match the main path's channel count

    x = x + shortcut
    return self.relu(x)

In [16]:
class ResNet(nn.Module):
  """Small residual CNN: three BasicBlocks, each followed by 2x2 average pooling, then a 3-layer MLP head.

  Args:
    num_class: number of output logits.
    input_size: side length in pixels of the square 3-channel input. It is
      used only to size the first linear layer. The default 32 (CIFAR-10)
      reproduces the original 4 * 4 * 256 = 4096 flattened features.

  Raises:
    ValueError: if input_size is too small to survive the three poolings.
  """
  def __init__(self, num_class=10, input_size=32):
    super(ResNet, self).__init__()
    # Feature extractor. Shapes are (H, W, C) after block + pooling, for input_size=32.
    self.b1 = BasicBlock(3, 64)     # (32, 32, 3)  -> (16, 16, 64)
    self.b2 = BasicBlock(64, 128)   # (16, 16, 64) -> (8, 8, 128)
    self.b3 = BasicBlock(128, 256)  # (8, 8, 128)  -> (4, 4, 256)
    self.pool = nn.AvgPool2d(kernel_size=2, stride=2)

    # Each of the three pools floors the side length by half, so the final side is input_size // 8.
    final_side = input_size // 8
    if final_side < 1:
      raise ValueError(f"input_size must be at least 8, got {input_size}")
    in_features = 256 * final_side * final_side  # 4096 for input_size=32

    # Classifier head
    self.fc1 = nn.Linear(in_features=in_features, out_features=2048)
    self.fc2 = nn.Linear(in_features=2048, out_features=512)
    self.fc3 = nn.Linear(in_features=512, out_features=num_class)
    self.relu = nn.ReLU()

  def forward(self, x):
    """Return raw (unnormalized) logits of shape (N, num_class) for x of shape (N, 3, H, W)."""
    x = self.pool(self.b1(x))
    x = self.pool(self.b2(x))
    x = self.pool(self.b3(x))
    x = torch.flatten(x, start_dim=1)
    x = self.relu(self.fc1(x))
    x = self.relu(self.fc2(x))
    x = self.fc3(x)
    return x

In [17]:
# Data: CIFAR-10
# Transforms (preprocessing / augmentation)

In [18]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'

# nn.Module.to moves parameters in place and returns the same module object.
model = ResNet().to(device)
model  # display the module tree

ResNet(
  (b1): BasicBlock(
    (c1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (c2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (downsample): Conv2d(3, 64, kernel_size=(1, 1), stride=(1, 1))
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (b2): BasicBlock(
    (c1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (c2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (downsample): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))
    (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (b3): BasicBlock(
    (c1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))


In [19]:
from torch.utils.data.dataloader import DataLoader
from torchvision.datasets.cifar import CIFAR10
from torchvision.transforms import Compose,RandomCrop, RandomHorizontalFlip
from torchvision.transforms import ToTensor,RandomVerticalFlip,Resize,Normalize
from torch.optim.adam import Adam

In [None]:
# Data loading: CIFAR-10, with random augmentation applied to the training split only.

# Per-channel CIFAR-10 statistics used for normalization.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.247, 0.243, 0.261)

# Training pipeline: pad by 4 px and randomly crop back to the native 32x32 CIFAR size.
# The previous 34x34 crop fed the model a size it is not documented for.
transforms = Compose([
    RandomCrop(32, padding=4),
    RandomHorizontalFlip(p=0.5),
    ToTensor(),
    Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
])

# Evaluation pipeline: deterministic (no random crop/flip), so test metrics are reproducible.
test_transforms = Compose([
    ToTensor(),
    Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
])

train_dataset = CIFAR10(root="./", train=True, download=True, transform=transforms)
test_dataset = CIFAR10(root="./", train=False, download=True, transform=test_transforms)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)


# Model training
from tqdm import tqdm

lr = 1e-4
num_epochs = 5
optim = Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()  # stateless; build once instead of on every step

# Make BatchNorm use batch statistics even if model.eval() was called in an earlier cell.
model.train()
for epoch in range(num_epochs):
  iterator = tqdm(train_loader)
  for data, label in iterator:
    optim.zero_grad()
    pred = model(data.to(device))
    loss = criterion(pred, label.to(device))
    loss.backward()
    optim.step()

    iterator.set_description(f'epoch:{epoch+1}  loss:{loss.item()} ')

Files already downloaded and verified
Files already downloaded and verified


epoch:1  loss:0.8724726438522339 : 100%|██████████| 782/782 [24:38<00:00,  1.89s/it]
epoch:2  loss:0.7588767409324646 : 100%|██████████| 782/782 [24:47<00:00,  1.90s/it]
epoch:3  loss:0.6110846996307373 :  42%|████▏     | 325/782 [10:33<15:16,  2.00s/it]