In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import sys
import os
import torch.nn.functional as F
from IPython.display import clear_output

In [2]:
import warnings

# Silence FutureWarnings globally so deprecation notices do not clutter cell output.
warnings.filterwarnings('ignore', category=FutureWarning)

# Fix the global torch RNG so weight initialisation, DataLoader shuffling and
# dropout masks are reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED);  # trailing ';' hides the returned Generator repr in Jupyter

In [3]:
# Load the preprocessed train/validation datasets saved earlier with torch.save.
# These are presumably pickled Dataset objects rather than bare tensors (TODO confirm),
# so weights_only=False is passed explicitly:
#   - torch >= 2.6 defaults weights_only=True, which rejects arbitrary pickled objects;
#   - torch 2.4/2.5 emit a FutureWarning when the argument is omitted.
# SECURITY: weights_only=False runs full unpickling — only load files you produced yourself.
train_ds = torch.load('../data/processed/train_ds.pt', weights_only=False)
val_ds = torch.load('../data/processed/val_ds.pt', weights_only=False)

In [4]:
# Wrap the datasets in DataLoaders that yield mini-batches.

batch_size = 8
# Each training batch is presumably (imgs, labels) with imgs shaped
# [batch_size, 3, 224, 224], the input size the CNN below is built for (TODO confirm).
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
# Accuracy does not depend on sample order, so validation keeps a fixed order
# (shuffle=False) to make evaluation passes deterministic and comparable.
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

In [5]:
class CNN(nn.Module):
    """Small three-stage convolutional image classifier.

    Each stage is Conv2d(3x3, padding=1) -> ReLU -> MaxPool2d(2), so the spatial
    size is halved (with flooring) three times before dropout and a single
    linear classification layer.

    Args:
        num_classes: number of output logits.
        input_size: height/width of the square input images. The linear layer is
            sized from this value, so ``forward`` only accepts inputs of exactly
            this size. Defaults to 224, the size this notebook was written for.

    Raises:
        ValueError: if ``input_size`` is too small to survive three 2x poolings.
    """

    def __init__(self, num_classes=10, input_size=224):
        super().__init__()

        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)   # 3 -> 16 low-level feature maps
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)  # 16 -> 32 higher-level feature maps
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)  # 32 -> 64 feature maps
        self.pool = nn.MaxPool2d(2, 2)  # halves height and width
        self.dropout = nn.Dropout(0.5)

        # Three floor-halving pools: ((s // 2) // 2) // 2 == s // 8.
        final_size = input_size // 8
        if final_size < 1:
            raise ValueError(
                f"input_size must be at least 8 to survive three 2x poolings, got {input_size}"
            )
        self.fc_input_size = 64 * final_size * final_size
        self.fc = nn.Linear(self.fc_input_size, num_classes)

    def forward(self, x):
        """Map images of shape (N, 3, input_size, input_size) to (N, num_classes) logits."""
        x = self.pool(F.relu(self.conv1(x)))  # low-level features
        x = self.pool(F.relu(self.conv2(x)))  # more complex features
        x = self.pool(F.relu(self.conv3(x)))
        x = torch.flatten(x, 1)  # (N, 64, s, s) -> (N, 64*s*s)
        x = self.dropout(x)
        return self.fc(x)  # raw class logits (CrossEntropyLoss applies softmax)


model = CNN(num_classes=10)

In [7]:
# Train the CNN, evaluate on the validation set after each epoch, and keep a
# checkpoint whenever validation accuracy improves.
# (`optim`, `os`, `torch`, `nn` come from the imports cell at the top.)

NUM_EPOCHS = 7
LEARNING_RATE = 0.0007
CHECKPOINT_DIR = '../models'

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
best_acc = float('-inf')  # -inf so the first epoch always produces a checkpoint


def _train_one_epoch(model, loader, criterion, optimizer):
    """Run one optimisation pass over `loader`; return the mean per-batch loss.

    Raises:
        ValueError: if `loader` yields no batches.
    """
    model.train()
    running_loss = 0.0
    n_batches = 0
    for imgs, labels in loader:
        optimizer.zero_grad()
        loss = criterion(model(imgs), labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        n_batches += 1
    if n_batches == 0:
        raise ValueError("training loader yielded no batches")
    return running_loss / n_batches


def _evaluate(model, loader):
    """Return top-1 accuracy of `model` on `loader`, in percent.

    Raises:
        ValueError: if `loader` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for imgs, labels in loader:
            predicted = model(imgs).argmax(dim=1)  # index of the highest logit
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        raise ValueError("validation loader yielded no samples")
    return correct * 100 / total


os.makedirs(CHECKPOINT_DIR, exist_ok=True)  # torch.save does not create directories

for i in range(NUM_EPOCHS):
    train_loss = _train_one_epoch(model, train_loader, criterion, optimizer)
    accuracy = _evaluate(model, val_loader)
    print(f"{i}: loss {train_loss:.2f},accuracy: {accuracy:.2f}%")

    if accuracy > best_acc:
        best_acc = accuracy
        torch.save(
            model.state_dict(),
            os.path.join(CHECKPOINT_DIR, f'best_model_acc_{accuracy:.1f}.pt'),
        )

0: loss 1.36,accuracy: 55.43%
1: loss 1.15,accuracy: 59.27%
2: loss 1.00,accuracy: 61.26%
3: loss 0.88,accuracy: 60.19%
4: loss 0.77,accuracy: 63.42%
5: loss 0.68,accuracy: 62.46%
6: loss 0.62,accuracy: 62.96%


In [None]:
# Persist the weights after the last epoch (independent of the best-accuracy checkpoint).
os.makedirs('../models', exist_ok=True)  # torch.save does not create missing directories
torch.save(model.state_dict(), '../models/final_model.pt')