In [1]:
import os

import torch
import torchvision
import torch.nn as nn
import torch.functional as F
import torchvision.models as models
import torchvision.transforms as transforms

# define Model

In [8]:
resnet_152 = models.resnet152(pretrained=True, progress=True)

## define device

In [9]:
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

resnet_152.to(device)

if device == torch.device('cuda') :
    print("device : 'cuda'")
else :
    print("device : 'cpu'")

device : 'cuda'


# Transform & DataLoader

In [3]:
transform = transforms.Compose(
    [transforms.ToTensor(),

    # channel number 만큼 수 전달.
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)


batch_size = 12


data_path = '../ResNet_utils/data'

trainset = torchvision.datasets.CIFAR10(
    root=data_path, train=True,
    download=True, transform=transform
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size,
    shuffle=True, num_workers=2
)


testset = torchvision.datasets.CIFAR10(
    root=data_path, train=False,
    download=True, transform=transform
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size,
    shuffle=True, num_workers=2
)

Files already downloaded and verified
Files already downloaded and verified


In [16]:
fashion_mnist = torchvision.datasets.FashionMNIST(
    root=data_path, train=True,
    download=True, transform=transform
)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ../ResNet_utils/data\FashionMNIST\raw\train-images-idx3-ubyte.gz
26422272it [00:17, 1483934.07it/s]                              
Extracting ../ResNet_utils/data\FashionMNIST\raw\train-images-idx3-ubyte.gz to ../ResNet_utils/data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ../ResNet_utils/data\FashionMNIST\raw\train-labels-idx1-ubyte.gz
29696it [00:00, 103765.73it/s]                          
Extracting ../ResNet_utils/data\FashionMNIST\raw\train-labels-idx1-ubyte.gz to ../ResNet_utils/data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http:

In [17]:
len(fashion_mnist)

60000

In [9]:
transform = transforms.Compose(
    [transforms.ToTensor(),

    # channel number 만큼 수 전달.
    transforms.Normalize((0.5), (0.5))]
)

mnist = torchvision.datasets.MNIST(
    root=data_path, train=True,
    download=True, transform=transform
)
mnistloader = torch.utils.data.DataLoader(
    mnist, batch_size=32,
    shuffle=True
)

In [10]:
for i, data in enumerate(mnistloader, 0):
    print(data[0].shape)

8, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([

In [19]:
len(mnist)

60000

In [11]:
transform = transforms.Compose(
    [transforms.ToTensor(),

    # channel number 만큼 수 전달.
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

imagenet = torchvision.datasets.ImageNet(
    root=data_path, split="train",
    download=True, transform=transform
)

RuntimeError: The dataset is no longer publicly accessible. You need to download the archives externally and place them in the root directory.

In [5]:
len(trainset)

50000

# define Optimizer

In [6]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(resnet_152.parameters(), lr=0.001)

# Training (with model CheckPoint)

In [12]:
model_path = '../ResNet_utils/trained_models/checkpoints/'

for epoch in range(2):

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # [inputs, labels]의 목록인 data로부터 입력을 받기
        inputs, labels = data[0].to(device), data[1].to(device)

        # Gradient parameter to Zeros
        optimizer.zero_grad()

        # foward/back propagation + optimization
        outputs = resnet_152(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print info
        running_loss += loss.item()
        if i % 2000 == 1999:
            print(f"[{epoch + 1}, {i + 1}] loss: {running_loss / 2000}")

            model_path_ = os.path.join(model_path, f"ResNet152_{epoch}_{i+1}_{running_loss/2000}.pth")
            torch.save({
                'epoch' : epoch,
                'model_state_dict' : resnet_152.state_dict(),
                'optimizer_state_dict' : optimizer.state_dict(),
                'loss' : running_loss / 2000
            }, model_path_)
            
            running_loss = 0.0
        
        torch.cuda.empty_cache()

print('Finished Training')

[1, 2000] loss: 1.574141943693161
[1, 4000] loss: 1.3726720532774925
[2, 2000] loss: 1.2510065106451511
[2, 4000] loss: 1.1630401315540075
Finished Training


# Saving & Loading Model for Inference

## Save/Load ***state_dict*** (Recommended)

### Save:

In [None]:
torch.save(model.state_dict, model_path)

### Load:

In [None]:
resnet_152 = models.resnet152()
resnet_152.load_state_dict(torch.load(model_path))

## Save/Load Entire Model

### Save:

In [None]:
torch.save(resnet_152, model_path) # => to .pt or .pth

### Load:

In [None]:
resnet_152 = torch.load(model_path)

## Saving & Loading a General Checkpoint for Inference and/or Resuming Training

### Save:

In [None]:
torch.save({
    'epoch' : epoch,
    'model_state_dict' : resnet_152.state_dict(),
    'optimizer_state_dict' : optimizer.state_dict(),
    'loss' : loss
}, model_path)

### Load:

In [None]:
resnet_152 = models.resnet152()
optimizer = optim.SGD()

checkpoint = torch.load(model_path)
resnet_152.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

## Saving Multiple Models in One File

### Save:

In [None]:
torch.save({
    'modelA_state_dict' : resnet_152A.state_dict(),
    'modelB_state_dict' : resnet_152B.state_dict(),
    'optimizerA_state_dict' : optimizerA.state_dict(),
    'optimizerB_state_dict' : optimizerB.state_dict()
}, model_path)

### Load:

In [None]:
resnet_152A = models.resnet152()
resnet_152B = models.resnet152()
optimizerA = optim.SGD()
optimizerB = optim.Adagrad()

checkpoint = torch.load(model_path)
resnet_152A.load_state_dict(chekcpoint['modelA_state_dict'])
resnet_152B.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])

