In [13]:
import torch
import torchvision
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms as T
from torchvision import io
import torchutils as tu
import json
import numpy as np
import matplotlib.pyplot as plt
import os
from PIL import Image
from torch.utils.data import TensorDataset, DataLoader, random_split

from torchvision import models
from torchvision.models import ResNet50_Weights
from torchvision.datasets import ImageFolder
from sklearn.metrics import precision_score, recall_score, f1_score


In [3]:
# Instantiate ResNet-50 with the default pretrained weights shipped by torchvision.
model = models.resnet50(weights=ResNet50_Weights.DEFAULT)

In [4]:
# Per-channel normalisation statistics (the values torchvision documents for its
# ImageNet-pretrained models).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing pipeline applied to every image, in order.
preprocessing_steps = [
    T.Resize(232),        # resize so the shorter side is 232 px
    T.CenterCrop(224),    # keep the central 224x224 patch
    T.ToTensor(),         # PIL image -> float tensor
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
]
transform = T.Compose(preprocessing_steps)

In [11]:
BATCH_SIZE = 32

# Both splits are read from class-per-subfolder directories and share one preprocessing pipeline.
train_ds = ImageFolder(root='images/seg_train', transform=transform)
valid_ds = ImageFolder(root='images/seg_test', transform=transform)

# Only the training batches are reshuffled every epoch; validation order stays fixed.
train_loader = DataLoader(dataset=train_ds, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_ds, batch_size=BATCH_SIZE, shuffle=False)

In [7]:
# Display the module tree of the loaded network as the cell output.
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [12]:
# Freeze every pretrained parameter so that only the new classification head is trained.
model.requires_grad_(False)

# Replace the classification head. The input width is read from the existing head
# rather than hard-coded, so the cell keeps working if the backbone is swapped for
# another ResNet variant (and it stays idempotent when the cell is re-run).
NUM_CLASSES = 6  # NOTE(review): presumably the number of class folders under images/seg_train — confirm
model.fc = nn.Linear(in_features=model.fc.in_features, out_features=NUM_CLASSES)

# A freshly created nn.Linear is trainable by default; stated explicitly so the
# intent (train the head only) does not depend on that default.
model.fc.requires_grad_(True)

In [None]:
LEARNING_RATE = 0.001

# Hand the optimizer only the parameters that are still trainable.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=LEARNING_RATE)

# Multi-class classification loss computed on the raw model outputs.
criterion = nn.CrossEntropyLoss()

In [14]:
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one full pass over `loader` and return its aggregate metrics.

    When `optimizer` is given the model is put in training mode and updated
    after every batch; otherwise the model is put in eval mode and gradients
    are disabled.

    Returns:
        tuple: (mean of per-batch losses, accuracy, macro precision,
        macro recall, macro F1) for this pass.

    Raises:
        ValueError: if `loader` yields no batches.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()

    batch_losses, pred_batches, label_batches = [], [], []

    with torch.set_grad_enabled(training):
        for samples, labels in loader:
            samples, labels = samples.to(device), labels.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(samples)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            batch_losses.append(loss.item())
            # Predicted class = index of the largest output along dim 1.
            pred_batches.append(outputs.argmax(dim=1).cpu())
            label_batches.append(labels.cpu())

    if not batch_losses:
        raise ValueError("fit(): data loader produced no batches")

    # Concatenate all batches so metrics are computed over the whole pass.
    preds = torch.cat(pred_batches)
    targets = torch.cat(label_batches)
    accuracy = (preds == targets).sum().item() / len(targets)

    # zero_division=0 yields the same value as sklearn's default ("warn")
    # but without an UndefinedMetricWarning when a class is never predicted.
    metric_kwargs = {"average": "macro", "zero_division": 0}
    y_true, y_pred = targets.numpy(), preds.numpy()

    return (
        np.mean(batch_losses),
        accuracy,
        precision_score(y_true, y_pred, **metric_kwargs),
        recall_score(y_true, y_pred, **metric_kwargs),
        f1_score(y_true, y_pred, **metric_kwargs),
    )


def fit(model, train_loader, valid_loader, optimizer, criterion, n_epochs, device='cpu'):
    """Train `model` for `n_epochs`, validating after every epoch.

    Args:
        model: module to train; moved to `device` before the first epoch.
        train_loader: iterable of (samples, labels) batches used for training.
        valid_loader: iterable of (samples, labels) batches used for validation.
        optimizer: optimizer stepped once per training batch.
        criterion: callable `criterion(outputs, labels)` returning a scalar loss.
        n_epochs: number of train + validation passes to run.
        device: device that the model and every batch are moved to.

    Returns:
        tuple of ten lists with one entry per epoch, in this order:
        (tr_loss, vl_loss, tr_acc, vl_acc, tr_prec, vl_prec,
         tr_rec, vl_rec, tr_f1, vl_f1).
        Loss is the mean of the per-batch losses; precision, recall and F1
        are macro-averaged.

    Raises:
        ValueError: if either loader yields no batches while n_epochs > 0.

    Note:
        Every training pass calls `model.train()`, which switches BatchNorm
        and Dropout modules to training mode regardless of `requires_grad`.
        With a frozen backbone the BatchNorm running statistics are therefore
        still updated during training.
    """
    tr_loss, vl_loss = [], []
    tr_acc, vl_acc = [], []
    tr_prec, vl_prec = [], []
    tr_rec, vl_rec = [], []
    tr_f1, vl_f1 = [], []

    train_history = (tr_loss, tr_acc, tr_prec, tr_rec, tr_f1)
    valid_history = (vl_loss, vl_acc, vl_prec, vl_rec, vl_f1)

    model.to(device)

    for n in range(n_epochs):
        train_stats = _run_epoch(model, train_loader, criterion, device, optimizer)
        valid_stats = _run_epoch(model, valid_loader, criterion, device)

        for series, value in zip(train_history, train_stats):
            series.append(value)
        for series, value in zip(valid_history, valid_stats):
            series.append(value)

        print(f"Epoch {n+1}/{n_epochs}: "
              f"tr_loss={tr_loss[-1]:.3f}, vl_loss={vl_loss[-1]:.3f}, "
              f"tr_acc={tr_acc[-1]:.3f}, vl_acc={vl_acc[-1]:.3f}, "
              f"tr_f1={tr_f1[-1]:.3f}, vl_f1={vl_f1[-1]:.3f}")

    return tr_loss, vl_loss, tr_acc, vl_acc, tr_prec, vl_prec, tr_rec, vl_rec, tr_f1, vl_f1