In [1]:
import torch
from torch import nn
from torchvision import models, datasets, transforms
import time
from tqdm.auto import tqdm

In [2]:
# Обязательно к прочтению: тред на тему различных состояний нейронной сети в PyTorch
# https://stackoverflow.com/questions/51748138/pytorch-how-to-set-requires-grad-false
def set_requires_grad(model, value=False):
    for param in model.parameters():
        param.requires_grad = value

In [3]:
num_classes = 10
input_size = 224
batch_size = 64

In [4]:
model = models.resnet18(pretrained=True)
set_requires_grad(model, False)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, num_classes)

In [5]:
next(model.fc.parameters()).requires_grad

True

In [6]:
normalize = transforms.Compose([
    transforms.Resize(input_size),
    transforms.CenterCrop(input_size),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

In [7]:
trainset = datasets.CIFAR10(root='./data', train=True,
                            download=True, transform=normalize)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = datasets.CIFAR10(root='./data', train=False,
                           download=True, transform=normalize)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

loaders = {'train': trainloader, 'val': testloader}
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [8]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [9]:
device

device(type='cuda', index=0)

In [10]:
# Optional: проверка выхода AdaptivePooling
# def print_hook(m, i):
#   print("Inside avgpool", i[0].shape)

# handle = model.avgpool.register_forward_pre_hook(print_hook)
# model(torch.ones(1,3,512,512))
# handle.remove()

In [11]:
model = model.to(device)
pretrain_optimizer = torch.optim.SGD(model.fc.parameters(),
                                     lr=0.001, momentum=0.9)
train_optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

In [12]:
criterion = nn.CrossEntropyLoss()

In [13]:
def train_model(model, dataloaders, criterion, optimizer,
                phases, num_epochs=3):
    start_time = time.time()

    acc_history = {k: list() for k in phases}
    loss_history = {k: list() for k in phases}

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in phases:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            n_batches = len(dataloaders[phase])
            for inputs, labels in tqdm(dataloaders[phase], total=n_batches):
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                    _, preds = torch.max(outputs, 1)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double()
            epoch_acc /= len(dataloaders[phase].dataset)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss,
                                                       epoch_acc))
            loss_history[phase].append(epoch_loss)
            acc_history[phase].append(epoch_acc)

        print()

    time_elapsed = time.time() - start_time
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60,
                                                        time_elapsed % 60))

    return model, acc_history

In [14]:
# Сигнатура вызова функции train_model
# train_model(model, loaders, criterion, optimizer,
#             phases=['train', 'val'], num_epochs=num_epochs)

# Pretrain
# запустить предобучение модели на две эпохи
# train_model(model, loaders, criterion, pretrain_optimizer,
#             phases=['train', 'val'], num_epochs=2)

# Train
# запустить дообучение модели
set_requires_grad(model, True)
train_model(model, loaders, criterion, train_optimizer,
            phases=['train', 'val'], num_epochs=1)


Epoch 0/0
----------


  0%|          | 0/782 [00:00<?, ?it/s]

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/IPython/core/interactiveshell.py", line 3343, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-14-ebe5c6ffb4eb>", line 14, in <module>
    phases=['train', 'val'], num_epochs=1)
  File "<ipython-input-13-34a14ffed359>", line 24, in train_model
    for inputs, labels in tqdm(dataloaders[phase], total=n_batches):
  File "/usr/local/lib/python3.6/dist-packages/tqdm/notebook.py", line 258, in __iter__
    for obj in it:
  File "/usr/local/lib/python3.6/dist-packages/tqdm/std.py", line 1195, in __iter__
    for obj in iterable:
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 521, in __next__
    data = self._next_data()
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 1203, in _next_data
    return self._process_data(data)
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py

TypeError: object of type 'NoneType' has no len()