In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import datasets, transforms, models
from torch.utils.data import random_split
import os
import time
import copy
import matplotlib.pyplot as plt
import os
import time
import copy
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler

In [4]:
# Set data directory
data_dir = '/Users/niezhiqi/Desktop/课程/深度学习/hw1/EX1/flower_dataset'

# Data augmentation and normalization for training and validation
data_transforms = transforms.Compose([
        # GRADED FUNCTION: Add five data augmentation methods, Normalizating and Tranform to tensor
        ### START SOLUTION HERE ###
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(30),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomGrayscale(p=0.1),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ### END SOLUTION HERE ###
])

# Load the entire dataset
full_dataset = datasets.ImageFolder(data_dir, data_transforms)

# Automatically split into 80% train and 20% validation
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size

train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

# Use DataLoader for both train and validation datasets
batch_size = 32
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

dataloaders = {'train': train_loader, 'val': val_loader}
dataset_sizes = {'train': len(train_dataset), 'val': len(val_dataset)}

# Get class names from the dataset
class_names = full_dataset.classes

In [5]:
print(class_names)


['daisy', 'dandelion', 'rose', 'sunflower', 'tulip']


In [None]:
# Load pre-trained model and modify the last layer
model = models.resnet18(pretrained=True)

# GRADED FUNCTION: Modify the last fully connected layer of model
### START SOLUTION HERE ###
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, len(class_names))  # 5 classes for flowers
### END SOLUTION HERE ###

# GRADED FUNCTION: Define the loss function
### START SOLUTION HERE ###
criterion = nn.CrossEntropyLoss() 
### END SOLUTION HERE ###

# GRADED FUNCTION: Define the optimizer
### START SOLUTION HERE ###
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
### END SOLUTION HERE ###

# Learning rate scheduler
scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# Training function
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    
    # Initialize lists to store training/validation statistics
    train_loss_history = []
    train_acc_history = []
    val_loss_history = []
    val_acc_history = []
    lr_history = []

    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    model = model.to(device)

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        # Print and store learning rate for current epoch
        current_lr = optimizer.param_groups[0]['lr']
        lr_history.append(current_lr)
        print(f'Learning Rate: {current_lr:.6f}')

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.float() / dataset_sizes[phase]

            # Store statistics
            if phase == 'train':
                train_loss_history.append(epoch_loss)
                train_acc_history.append(epoch_acc.cpu().numpy())
            else:
                val_loss_history.append(epoch_loss)
                val_acc_history.append(epoch_acc.cpu().numpy())

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # Save the model if validation accuracy is the best so far
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                save_dir = 'Ex2/work_dir'
                os.makedirs(save_dir, exist_ok=True)
                torch.save(model.state_dict(), os.path.join(save_dir, 'best_model.pth'))
                
                # Also save the training statistics
                torch.save({
                    'train_loss_history': train_loss_history,
                    'train_acc_history': train_acc_history,
                    'val_loss_history': val_loss_history,
                    'val_acc_history': val_acc_history,
                    'lr_history': lr_history
                }, os.path.join(save_dir, 'training_stats.pth'))

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:4f}')

    # Load best model weights
    model.load_state_dict(best_model_wts)
    
    # Plot training/validation statistics
    plt.figure(figsize=(15, 5))
    
    # Plot loss
    plt.subplot(1, 2, 1)
    plt.plot(train_loss_history, label='Train Loss')
    plt.plot(val_loss_history, label='Validation Loss')
    plt.title('Training and Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    
    # Plot accuracy
    plt.subplot(1, 2, 2)
    plt.plot(train_acc_history, label='Train Accuracy')
    plt.plot(val_acc_history, label='Validation Accuracy')
    plt.title('Training and Validation Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    
    # Save plots
    plt.savefig(os.path.join(save_dir, 'training_curves.png'))
    plt.show()
    
    # Plot learning rate schedule
    plt.figure()
    plt.plot(lr_history)
    plt.title('Learning Rate Schedule')
    plt.xlabel('Epoch')
    plt.ylabel('Learning Rate')
    plt.savefig(os.path.join(save_dir, 'learning_rate.png'))
    plt.show()

    return model



In [None]:
# Train the model
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model = model.to(device)
model = train_model(model, criterion, optimizer, scheduler, num_epochs=30)


Epoch 0/29
----------
Learning Rate: 0.001000
train Loss: 0.9678 Acc: 0.6326
val Loss: 0.5004 Acc: 0.8316

Epoch 1/29
----------
Learning Rate: 0.001000
train Loss: 0.4883 Acc: 0.8240
val Loss: 0.3707 Acc: 0.8842

Epoch 2/29
----------
Learning Rate: 0.001000
train Loss: 0.3953 Acc: 0.8565
val Loss: 0.3598 Acc: 0.8877

Epoch 3/29
----------
Learning Rate: 0.001000
train Loss: 0.3735 Acc: 0.8661
val Loss: 0.2732 Acc: 0.9053

Epoch 4/29
----------
Learning Rate: 0.001000
train Loss: 0.3454 Acc: 0.8740
val Loss: 0.2471 Acc: 0.9070

Epoch 5/29
----------
Learning Rate: 0.001000
train Loss: 0.3151 Acc: 0.8845
val Loss: 0.2580 Acc: 0.9035

Epoch 6/29
----------
Learning Rate: 0.001000
train Loss: 0.2975 Acc: 0.8911
val Loss: 0.2463 Acc: 0.9105

Epoch 7/29
----------
Learning Rate: 0.000100
train Loss: 0.2532 Acc: 0.9100
val Loss: 0.2265 Acc: 0.9193

Epoch 8/29
----------
Learning Rate: 0.000100
train Loss: 0.2585 Acc: 0.9052
val Loss: 0.2398 Acc: 0.9140

Epoch 9/29
----------
Learning Rate: 