In [13]:
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
from torch.utils.data.sampler import SubsetRandomSampler

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

#dataset
transform = transforms.Compose([transforms.Resize((32,32)), transforms.ToTensor(), transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))])

train = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
valid = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

num = len(train)
indices = list(range(num))
split = int(np.floor(0.1 * num))
train_idx, valid_idx = indices[split:], indices[:split]
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

trainloader = torch.utils.data.DataLoader(train, batch_size=64, sampler=train_sampler)
validloader = torch.utils.data.DataLoader(valid, batch_size=64, sampler=valid_sampler)
testloader = torch.utils.data.DataLoader(test, batch_size=64, shuffle=False)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


#Model
from torch import nn
from torch.nn import functional as F
import torch.optim as optim

class Net(nn.Module):
    def __init__(self):
      super(Net, self).__init__()
      self.fc1 = nn.Linear(1024*3, 512*3).cuda(device)
      self.fc2 = nn.Linear(512*3, 256*3).cuda(device)
      self.fc3 = nn.Linear(256*3, 128*3).cuda(device)
      self.fc4 = nn.Linear(128*3, 64*3).cuda(device)
      self.fc5 = nn.Linear(64*3, 32*3).cuda(device)
      self.fc6 = nn.Linear(32*3, 10).cuda(device)
    
    def forward(self, x):
      x = x.view(-1, 1024*3)
      x.cuda(device)
      x = F.relu(self.fc1(x))
      x = F.relu(self.fc2(x))
      x = F.relu(self.fc3(x))
      x = F.relu(self.fc4(x))
      x = F.relu(self.fc5(x))
      x = self.fc6(x)
      return F.log_softmax(x, dim=1)

model = Net()

#hyperparameter
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001)

#Train
for epoch in range(1,101):
  acc=0
  train_loss = 0.0
  for image, label in trainloader:
    image, label = image.to(device), label.to(device)
    optimizer.zero_grad()
    output = model(image)
    loss = criterion(output, label)
    loss.backward()
    optimizer.step()
    train_loss += loss.item()
    pred = output.argmax(dim=1, keepdim=True)
    acc += pred.eq(label.view_as(pred)).sum().item()

  if epoch % 10 == 0:
    correct = 0
    valid_loss = 0.0
    for image, label in validloader:
      image, label = image.to(device), label.to(device)
      output = model(image)
      loss = criterion(output, label)
      valid_loss += loss.item()
      vpred = output.argmax(dim=1, keepdim=True)
      correct += vpred.eq(label.view_as(vpred)).sum().item()
    print('Epoch:{}/{}, train_loss:{}, train_Accuracy:{}, valid_loss:{}, validation Accuracy: {}'.format(epoch, 100, train_loss / 45000, 100 * acc / 45000, valid_loss / 5000, 100 * correct / 5000))

#Test acc
correct = 0
with torch.no_grad():
  for image, label in testloader:
    image, label = image.to(device), label.to(device)
    output = model(image)
    tpred = output.argmax(dim=1, keepdim=True)
    correct += tpred.eq(label.view_as(tpred)).sum().item()

print('Test Accuracy: {}'.format(100 * correct / len(testloader.dataset)))

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Epoch:10/100, train_loss:0.03598873875935872, train_Accuracy:13.102222222222222, valid_loss:0.03634411196708679, validation Accuracy: 14.3
Epoch:20/100, train_loss:0.03585971216625637, train_Accuracy:21.72222222222222, valid_loss:0.03619255495071411, validation Accuracy: 23.18
Epoch:30/100, train_loss:0.034541379822625055, train_Accuracy:19.435555555555556, valid_loss:0.03463901958465576, validation Accuracy: 20.18
Epoch:40/100, train_loss:0.03183430882294973, train_Accuracy:23.626666666666665, valid_loss:0.03198616662025452, validation Accuracy: 24.74
Epoch:50/100, train_loss:0.029635790440771314, train_Accuracy:28.982222222222223, valid_loss:0.029899160361289978, validation Accuracy: 29.54
Epoch:60/100, train_loss:0.028214421060350207, train_Accuracy:33.111111111111114, valid_loss:0.02862076075077057, validation Accuracy: 33.74
Epoch:70/100, train_loss:0.02687367467615339