In [1]:
import copy

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import matplotlib.pyplot as plt


if torch.backends.mps.is_available():
    device = torch.device("mps")   # GPU на Mac (Apple Silicon або AMD)
elif torch.cuda.is_available():
    device = torch.device("cuda")  # NVIDIA GPU (не на Mac)
else:
    device = torch.device("cpu")   # fallback

In [2]:
batch_size=32
num_epochs = 20
patience = 5
counter = 0
lr=1e-3
weight_decay=1e-4
PATH='./CIFAR100_ResNET.pth'

In [3]:
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.RandomHorizontalFlip(),
                                transforms.RandomCrop(32, padding=4),
                                transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
                                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

transform_test = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

train_dataset = torchvision.datasets.CIFAR100(root='./cifar100',
                                           download=True,
                                           transform=transform,
                                           train=True)

test_dataset = torchvision.datasets.CIFAR100(root='./cifar100',
                                             download=True,
                                             transform=transform_test,
                                             train=False)

In [4]:
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=8)

test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=batch_size,
                                          shuffle=True)

In [5]:
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
model.fc = nn.Linear(model.fc.in_features, 100)
model = model.to(device)

for params in model.parameters():
    params.requires_grad = False
model.fc = model.fc.requires_grad_(True)

In [8]:
def run_one_epoch(model_f, train_loader_f, device_f, criterion_f, optimizer_f):
    model_f.train()
    running_loss, correct = 0.0, 0
    for images_f, labels_f in train_loader_f:
        images_f, labels_f = images_f.to(device_f), labels_f.to(device)

        outputs_f = model(images_f)

        l = criterion_f(outputs_f, labels_f)
        l.backward()

        optimizer_f.step()
        optimizer_f.zero_grad()
        running_loss += l.item()
        correct += (outputs_f.argmax(1) == labels_f).sum().item()

    return running_loss/len(train_loader_f), correct/len(train_loader_f.dataset)



In [9]:

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

best_train_loss = float('inf')
best_model_wts = copy.deepcopy(model.state_dict())

train_losses, train_accs = [], []


for epoch in range(num_epochs):
    training_loss, training_accuracy = run_one_epoch(model, train_loader, device, criterion, optimizer)

    train_losses.append(training_loss)
    train_accs.append(training_accuracy)

    print(f"Epoch [{epoch+1}/{num_epochs}] "
          f"Train Loss: {training_loss:.4f}, Train Acc: {training_accuracy:.4f} ")

    if training_loss < best_train_loss:
        best_train_loss = training_loss
        best_model_wts = copy.deepcopy(model.state_dict())
        counter = 0
    else:
        counter +=1
        if counter > patience:
            print('Early stop triggered')
            break


Epoch [1/20] Train Loss: 4.1243, Train Acc: 0.1251 
Epoch [2/20] Train Loss: 3.8663, Train Acc: 0.1679 
Epoch [3/20] Train Loss: 3.8183, Train Acc: 0.1735 
Epoch [4/20] Train Loss: 3.8071, Train Acc: 0.1740 
Epoch [5/20] Train Loss: 3.7751, Train Acc: 0.1801 
Epoch [6/20] Train Loss: 3.7835, Train Acc: 0.1841 
Epoch [7/20] Train Loss: 3.7996, Train Acc: 0.1828 
Epoch [8/20] Train Loss: 3.7933, Train Acc: 0.1851 
Epoch [9/20] Train Loss: 3.8012, Train Acc: 0.1814 
Epoch [10/20] Train Loss: 3.7799, Train Acc: 0.1852 
Epoch [11/20] Train Loss: 3.7817, Train Acc: 0.1874 
Early stop triggered


In [None]:

model.load_state_dict(best_model_wts)

for params in model.parameters():
    params.requires_grad = True

fine_tune_epoch = 10
for epoch in range(fine_tune_epoch):
    training_loss, training_accuracy = run_one_epoch(model, train_loader, device, criterion, optimizer)
    print(f"Epoch [{epoch+1}/{num_epochs}] "
          f"Train Loss: {training_loss:.4f}, Train Acc: {training_accuracy:.4f} ")

torch.save(model, PATH)
print('Training finished')


In [None]:
load_model = models.resnet50()
load_model.load_state_dict(torch.load(PATH))
load_model.eval()

with torch.no_grad():
    n_correct = 0
    n_samples = len(test_loader.dataset)
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)

        _, predicted = torch.max(outputs, 1)
        n_correct += (predicted == labels).sum().item()

    acc = n_correct / n_samples
    print(f'Accuracy of the network on the {n_samples} test image: {100*acc}%')


In [None]:
plt.figure(figsize=(12,5))

plt.subplot(1,2,1)
plt.plot(train_losses, label='Train Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Loss Curve (Frozen + Fine-tune)')
plt.legend()

plt.subplot(1,2,2)
plt.plot(train_accs, label='Train Acc')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Accuracy Curve (Frozen + Fine-tune)')
plt.legend()

plt.show()