In [None]:
import torch
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

In [None]:
# --- Optimisation schedule ---
epochs = 20        # number of passes over the training loader
lr = 1e-3          # initial learning rate handed to the SGD optimizer
lr_step = 10       # StepLR period: epochs between learning-rate decays

# --- Data loading ---
batch_size = 16    # samples per batch for the train and validation loaders
num_workers = 4    # worker count passed to each DataLoader

# --- Logging ---
save_steps = 10    # training metrics are sampled once every `save_steps` batches

In [None]:
class RunningAverage():
    """Accumulate per-step loss and accuracy values and report their means.

    Call ``update`` once for every measured step, then call the instance
    itself to obtain ``(mean_loss, mean_acc)`` over all steps recorded so far.
    """

    def __init__(self):
        self.steps = 0      # number of update() calls so far
        self.loss_sum = 0   # sum of all recorded losses
        self.acc_sum = 0    # sum of all recorded accuracies

    def update(self, loss, acc):
        """Record the loss and accuracy of one step."""
        self.loss_sum += loss
        self.acc_sum += acc
        self.steps += 1

    def __call__(self):
        """Return ``(mean_loss, mean_acc)``.

        Returns ``(0.0, 0.0)`` when nothing has been recorded yet, instead of
        raising ``ZeroDivisionError``.
        """
        if self.steps == 0:
            # Nothing recorded (e.g. the dataloader yielded no batches).
            return 0.0, 0.0
        steps = float(self.steps)
        return self.loss_sum / steps, self.acc_sum / steps

In [None]:
import os

def save_checkpoint(state, save_dir="pretrained_tm_weights", filename="best.pth"):
    """Serialize ``state`` with ``torch.save`` to ``<save_dir>/<filename>``.

    Args:
        state: Object to serialize (the training loop passes a dict holding
            the epoch number and the model/optimizer state dicts).
        save_dir: Directory to write into. It is created, together with any
            missing parent directories, if it does not exist yet.
        filename: Name of the checkpoint file inside ``save_dir``. An existing
            file with the same name is overwritten.

    Returns:
        The path of the written checkpoint file.
    """
    # exist_ok avoids the check-then-create race of `exists()` + `mkdir()`
    # and, unlike `mkdir`, also handles nested target directories.
    os.makedirs(save_dir, exist_ok=True)

    filepath = os.path.join(save_dir, filename)
    torch.save(state, filepath)
    return filepath

In [None]:
import os
import json
import random
from PIL import Image

import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms


def read_split_data(root, val_rate=0.2, seed=None):
    """Scan a class-per-folder image directory and split it into train/val sets.

    Every sub-directory of ``root`` is treated as one class; class indices are
    assigned in sorted order of the directory names. For each class,
    ``int(len(images) * val_rate)`` entries are drawn at random for validation
    and the rest go to training.

    Args:
        root: Dataset root directory containing one sub-directory per class.
        val_rate: Fraction of each class to hold out for validation.
        seed: Optional seed for a private random generator, making the split
            reproducible. With ``None`` (the default) the global ``random``
            module state is used, as before.

    Returns:
        A tuple ``(train_images_path, train_images_label, val_images_path,
        val_images_label)`` of parallel lists of file paths and integer labels.

    Raises:
        AssertionError: If ``root`` does not exist or either split is empty.
    """
    assert os.path.exists(root), "dataset root: {} does not exist.".format(root)

    # Sorted so that the name -> index mapping is stable across runs.
    flower_class = sorted(cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla)))
    class_indices = {name: index for index, name in enumerate(flower_class)}
    print(class_indices)

    # A dedicated generator keeps a seeded split independent of global state.
    rng = random if seed is None else random.Random(seed)

    train_images_path = []
    train_images_label = []
    val_images_path = []
    val_images_label = []
    every_class_num = []

    for cla in flower_class:
        cla_path = os.path.join(root, cla)
        images = sorted(os.path.join(cla_path, i) for i in os.listdir(cla_path))
        image_class = class_indices[cla]

        every_class_num.append(len(images))
        # A set makes the membership test below O(1) instead of a list scan.
        val_path = set(rng.sample(images, k=int(len(images) * val_rate)))

        # Split into training and validation sets
        for img_path in images:
            if img_path in val_path:
                val_images_path.append(img_path)
                val_images_label.append(image_class)
            else:
                train_images_path.append(img_path)
                train_images_label.append(image_class)

    print("{} images were found in the dataset.".format(sum(every_class_num)))
    print("{} images for training.".format(len(train_images_path)))
    print("{} images for validation.".format(len(val_images_path)))
    assert len(train_images_path) > 0, "number of training images must greater than 0."
    assert len(val_images_path) > 0, "number of validation images must greater than 0."

    return train_images_path, train_images_label, val_images_path, val_images_label


class MyDataSet(Dataset):
    """Custom dataset that loads images from file paths with integer labels."""

    def __init__(self, images_path, images_class, transform=None):
        """Store the parallel path/label sequences and the optional transform.

        Args:
            images_path: Sequence of image file paths.
            images_class: Sequence of labels, aligned with ``images_path``.
            transform: Optional callable applied to each loaded image.
        """
        self.images_path = images_path
        self.images_class = images_class
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.images_path)

    def __getitem__(self, item):
        """Load sample ``item`` and return ``(image, label)`` as tensors."""
        # Image.open only reads the header and keeps the file open; convert()
        # forces the decode so the handle can be released right away. It also
        # guarantees three channels, so grayscale / RGBA / palette files do
        # not break a 3-channel Normalize in the transform pipeline.
        with Image.open(self.images_path[item]) as raw:
            img = raw.convert("RGB")
        label = self.images_class[item]

        if self.transform is not None:
            img = self.transform(img)

        return torch.as_tensor(img), torch.as_tensor(label)


# Dataset root: read_split_data treats each sub-directory as one class.
data_root = "/kaggle/input/flowers/flower_photos"
(train_images_path, train_images_label,
 val_images_path, val_images_label) = read_split_data(data_root)

# =============================== Transform ===============================
img_size = 224

# Per-channel mean/std normalisation, identical for both pipelines.
normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

# Training: random crop + horizontal flip as augmentation.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(img_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])

# Validation: deterministic resize (int(224 * 1.143) == 256) then centre crop.
val_transform = transforms.Compose([
    transforms.Resize(int(img_size * 1.143)),
    transforms.CenterCrop(img_size),
    transforms.ToTensor(),
    normalize,
])

# =============================== DataSet ===============================

train_dataset = MyDataSet(images_path=train_images_path, images_class=train_images_label,
                          transform=train_transform)
val_dataset = MyDataSet(images_path=val_images_path, images_class=val_images_label,
                        transform=val_transform)

# =============================== DataLoader ===============================

# Settings shared by both loaders; only the training loader shuffles.
loader_options = dict(batch_size=batch_size, pin_memory=True, num_workers=num_workers)

train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = torch.utils.data.DataLoader(val_dataset, shuffle=False, **loader_options)

# Number of batches per epoch for each loader.
print("len(train_loader) = {}".format(len(train_loader)))
print("len(val_loader) = {}".format(len(val_loader)))

In [None]:
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR


def _batch_accuracy(output_batch, labels_batch):
    """Return the fraction of samples whose arg-max prediction equals the label."""
    outputs = output_batch.detach().cpu().numpy()
    labels = labels_batch.detach().cpu().numpy()
    predict_labels = np.argmax(outputs, axis=1)
    return np.sum(predict_labels == labels) / float(labels.size)


def train_and_evaluate(model, train_dataloader, val_dataloader, criteria, optimizer, scheduler, epochs, save_steps):
    """Train ``model`` for ``epochs`` epochs, validating after each one.

    After every epoch the model is evaluated on ``val_dataloader``; whenever
    the validation accuracy is at least as high as the best seen so far, a
    checkpoint is written through ``save_checkpoint``.

    Relies on the module-level ``device``, ``RunningAverage`` and
    ``save_checkpoint`` defined elsewhere in the notebook.

    Args:
        model: Network to optimise; expected to already live on ``device``.
        train_dataloader: Iterable yielding ``(inputs, labels)`` training batches.
        val_dataloader: Iterable yielding ``(inputs, labels)`` validation batches.
        criteria: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimizer updating the parameters of ``model``.
        scheduler: Learning-rate scheduler, stepped once per epoch.
        epochs: Number of epochs to run.
        save_steps: Training metrics are sampled on every ``save_steps``-th
            batch (batch 0 included); must be non-zero.
    """
    best_val_acc = 0.0
    for epoch in range(epochs):

        print("Epoch {}/{}".format(epoch + 1, epochs))

        # ---------- train ------------

        model.train()
        metric_avg = RunningAverage()

        for i, (train_batch, labels_batch) in enumerate(train_dataloader):

            train_batch, labels_batch = train_batch.to(device), labels_batch.to(device)
            output_batch = model(train_batch)
            loss = criteria(output_batch, labels_batch)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Only a subset of batches is measured, to keep the
            # device-to-host copies needed for the accuracy cheap.
            if i % save_steps == 0:
                metric_avg.update(loss.item(), _batch_accuracy(output_batch, labels_batch))

        scheduler.step()
        train_loss, train_acc = metric_avg()
        print("- Train metrics: loss={:.2f}, acc={:.2f}".format(train_loss, train_acc))

        # ---------- validate ------------
        model.eval()
        metric_avg = RunningAverage()

        # No gradients are needed for evaluation; disabling autograd avoids
        # building a graph for every batch and saves memory and time.
        with torch.no_grad():
            for val_batch, labels_batch in val_dataloader:
                val_batch, labels_batch = val_batch.to(device), labels_batch.to(device)

                output_batch = model(val_batch)
                loss = criteria(output_batch, labels_batch)

                # Note: the reported value is the mean of per-batch accuracies.
                metric_avg.update(loss.item(), _batch_accuracy(output_batch, labels_batch))

        val_loss, val_acc = metric_avg()
        print("- Validate metrics: loss={:.2f}, acc={:.2f}".format(val_loss, val_acc))

        # ---------- Save weights ------------

        # `>=` means a tie with the current best also refreshes the checkpoint.
        is_best = val_acc >= best_val_acc
        if is_best:
            print("- Found new best accuracy")
            best_val_acc = val_acc

            save_checkpoint({'epoch': epoch + 1,
                             'state_dict': model.state_dict(),
                             'optim_dict' : optimizer.state_dict()})


In [None]:
import torch.nn as nn
import torchvision.models as models

# ResNet-50 backbone initialised with torchvision's default pretrained weights.
backbone = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)

# The backbone is kept whole; a small head maps its 1000 outputs to 5 classes.
resnet50 = nn.Sequential(
    backbone,
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(in_features=1000, out_features=5, bias=True),
)

resnet50.to(device)

In [None]:
import torchvision.models as models
import torch.nn.functional as F

# Loss, optimizer and learning-rate schedule for the classifier.
criteria = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=resnet50.parameters(), lr=lr, momentum=0.9)
# Multiply the learning rate by 0.1 every `lr_step` epochs.
scheduler = StepLR(optimizer=optimizer, step_size=lr_step, gamma=0.1)

train_and_evaluate(
    model=resnet50,
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    criteria=criteria,
    optimizer=optimizer,
    scheduler=scheduler,
    epochs=epochs,
    save_steps=save_steps,
)