In [1]:
import torch
import torchvision
import torchvision.transforms as transforms

# Seed

In [2]:
seed_number = 44

torch.manual_seed(seed_number)
torch.cuda.manual_seed(seed_number)
torch.cuda.manual_seed_all(seed_number) # if use multi-GPU
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
#np.random.seed(seed_number)
#random.seed(seed_number)

# CIFAR10 data download

In [3]:
train_dataset = torchvision.datasets.CIFAR10('./data/', train=True, download=True)
test_dataset  = torchvision.datasets.CIFAR10('./data/', train=False)

Files already downloaded and verified


# Load dataset

In [4]:
X      = train_dataset.data
X_test = test_dataset.data

X      = torch.tensor(X, dtype=torch.uint8)
X_test = torch.tensor(X_test, dtype=torch.uint8)

print(X.shape)
print(X_test.shape)

torch.Size([50000, 32, 32, 3])
torch.Size([10000, 32, 32, 3])


In [5]:
X = X.permute(0,3,1,2)
X_test = X_test.permute(0,3,1,2)

print(X.shape)
print(X_test.shape)

torch.Size([50000, 3, 32, 32])
torch.Size([10000, 3, 32, 32])


In [6]:
y      = train_dataset.targets
y_test = test_dataset.targets

y      = torch.tensor(y, dtype=torch.uint8)
y_test = torch.tensor(y_test, dtype=torch.uint8)

print(y.shape)
print(y_test.shape)

torch.Size([50000])
torch.Size([10000])


# Normalize

In [7]:
X = X / 255.
X_test  = X_test / 255.

# ImageNet statistics
X = torchvision.transforms.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225))(X)
X_test  = torchvision.transforms.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225))(X_test)

print(X.shape)
print(X_test.shape)

torch.Size([50000, 3, 32, 32])
torch.Size([10000, 3, 32, 32])


# Split dataset

In [8]:
from sklearn.model_selection import train_test_split
X_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=0.2, random_state=seed_number)

print(X_train.shape, y_train.shape)
print(X_valid.shape, y_valid.shape)

torch.Size([40000, 3, 32, 32]) torch.Size([40000])
torch.Size([10000, 3, 32, 32]) torch.Size([10000])


# Resize

In [9]:
X_train = torchvision.transforms.Resize((224,224))(X_train)
X_valid = torchvision.transforms.Resize((224,224))(X_valid)
X_test  = torchvision.transforms.Resize((224,224))(X_test)

print(X_train.shape)
print(X_valid.shape)
print(X_test.shape)



torch.Size([40000, 3, 224, 224])
torch.Size([10000, 3, 224, 224])
torch.Size([10000, 3, 224, 224])


# Data loader

In [10]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

X_train = X_train.to(device=device)
X_valid = X_valid.to(device=device)
X_test  = X_test.to(device=device)

y_train = y_train.to(device=device)
y_valid = y_valid.to(device=device)
y_test  = y_test.to(device=device)

cpu


In [11]:
train_ds = torch.utils.data.TensorDataset(X_train, y_train)
valid_ds = torch.utils.data.TensorDataset(X_valid, y_valid)

train_loader = torch.utils.data.DataLoader(train_ds, batch_size=512, shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_ds, batch_size=512, shuffle=False)

# Model

In [12]:
dir(torchvision.models)

['AlexNet',
 'AlexNet_Weights',
 'ConvNeXt',
 'ConvNeXt_Base_Weights',
 'ConvNeXt_Large_Weights',
 'ConvNeXt_Small_Weights',
 'ConvNeXt_Tiny_Weights',
 'DenseNet',
 'DenseNet121_Weights',
 'DenseNet161_Weights',
 'DenseNet169_Weights',
 'DenseNet201_Weights',
 'EfficientNet',
 'EfficientNet_B0_Weights',
 'EfficientNet_B1_Weights',
 'EfficientNet_B2_Weights',
 'EfficientNet_B3_Weights',
 'EfficientNet_B4_Weights',
 'EfficientNet_B5_Weights',
 'EfficientNet_B6_Weights',
 'EfficientNet_B7_Weights',
 'EfficientNet_V2_L_Weights',
 'EfficientNet_V2_M_Weights',
 'EfficientNet_V2_S_Weights',
 'GoogLeNet',
 'GoogLeNetOutputs',
 'GoogLeNet_Weights',
 'Inception3',
 'InceptionOutputs',
 'Inception_V3_Weights',
 'MNASNet',
 'MNASNet0_5_Weights',
 'MNASNet0_75_Weights',
 'MNASNet1_0_Weights',
 'MNASNet1_3_Weights',
 'MaxVit',
 'MaxVit_T_Weights',
 'MobileNetV2',
 'MobileNetV3',
 'MobileNet_V2_Weights',
 'MobileNet_V3_Large_Weights',
 'MobileNet_V3_Small_Weights',
 'RegNet',
 'RegNet_X_16GF_Weights'

In [14]:
backbone = torchvision.models.resnet34(weights="IMAGENET1K_V1")
backbone

Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to C:\Users\user/.cache\torch\hub\checkpoints\resnet34-b627a593.pth
100.0%


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [15]:
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        
        # https://pytorch.org/vision/stable/models.html
        self.backbone = torchvision.models.resnet34(weights="IMAGENET1K_V1")
                
        # modified layer
        self.backbone.fc = torch.nn.Linear(512, 10)
        
    def forward(self, xb):       
        out = self.backbone(xb)
        return out

In [16]:
model = Net().to(device)
print(model)

Net(
  (backbone): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_runnin

# Cost function

In [17]:
criterion = torch.nn.CrossEntropyLoss()

# Optimizer

In [18]:
pretrained_params = [param for name, param in model.named_parameters() if 'fc' not in str(name)]
optimizer = torch.optim.Adam(
    [{'params': pretrained_params},
     {'params': model.backbone.fc.parameters(), 'lr': 1e-4}], 
    lr=1e-5
    )

# Training

In [19]:
import time 

for epoch in range(3):
    start = time.time()
    train_loss, train_count = 0., 0.
    
    for xb, yb in train_loader:
        prediction = model(xb)
        loss = criterion(prediction, yb)
        train_loss += loss.item()*len(yb)
        train_count += len(yb)
        
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    
        
    with torch.no_grad():
        valid_loss = 0
        valid_accuracy, valid_count = 0., 0.
        
        for xb, yb in valid_loader:
            prediction = model(xb)
            valid_loss += criterion(prediction, yb)*len(yb)
            valid_accuracy += prediction.data.max(1)[1].eq(yb.data).sum()
            valid_count += len(yb)
        
    train_loss /= train_count
    valid_loss /= valid_count    
    valid_accuracy /= valid_count
            
    print(f"======== Epoch {epoch+1} =======")
    print(f"Loss => train:{train_loss:.5f}, valid:{valid_loss:.5f}")
    print(f"Accuracy => {valid_accuracy*100:.2f}%, Elapsed time => {time.time()-start:.3f} sec")
    print("=======================================\n")

Loss => train:1.33548, valid:0.65280
Accuracy => 82.03%, Elapsed time => 6227.444 sec

Loss => train:0.44356, valid:0.34071
Accuracy => 89.76%, Elapsed time => 8428.341 sec

Loss => train:0.25064, valid:0.25375
Accuracy => 91.85%, Elapsed time => 9592.719 sec

