In [15]:
import os
import sys
import time
import copy
import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from ST_tools import *

In [10]:
# Run on the first CUDA GPU when PyTorch can see one; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

## AlexNet

In [None]:
class AlexNet(nn.Module):
    """AlexNet-style convolutional classifier.

    The feature extractor is five convolutions (48, 128, 192, 192 and 128
    output channels) interleaved with ReLU and three 3x3/stride-2 max-pools.
    The classifier is three fully connected layers with dropout in front of
    the first two.

    The first linear layer expects ``128 * 6 * 6`` inputs, which is what the
    feature extractor produces for a 3 x 224 x 224 image
    (224 -> 55 -> 27 -> 13 -> 6 spatially).

    Args:
        num_classes: Size of the final linear layer (number of output scores).
        init_weights: When true, re-initialise all Conv2d/Linear layers via
            ``_initialize_weights`` instead of keeping PyTorch's defaults.
    """

    def __init__(self, num_classes=1000, init_weights=False):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=11, stride=4,
                      padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(48, 128, kernel_size=5,
                      padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(128, 192, kernel_size=3,
                      padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3,
                      padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3,
                      padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        """Return class scores of shape (batch, num_classes) for an image batch."""
        x = self.features(x)
        # Collapse (channels, height, width) into one vector per sample.
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        """Re-initialise every Conv2d and Linear layer in the network.

        Convolutions get Kaiming-normal weights (fan-out, ReLU gain); linear
        layers get N(0, 0.01) weights. All biases are set to zero.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight,
                                        mode='fan_out',
                                        nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
# Build the AlexNet variant defined above and move it to `device`.
# TODO: `xxxx` is a redacted placeholder - replace it with the number of target classes.
model = AlexNet(num_classes=xxxx, init_weights=True).to(device)

## ResNet

In [None]:
# `models` is not among the notebook's explicit imports; import it here so this
# cell does not depend on the `ST_tools` star-import happening to expose it.
from torchvision import models

# ResNet-50 backbone with pretrained weights.
# NOTE(review): recent torchvision releases deprecate `pretrained=` in favour
# of `weights=` - switch once the installed version is confirmed.
model = models.resnet50(pretrained=True)

# Swap the original classification layer for a 6-way head.
# NOTE(review): the class count is hard-coded to 6 here, unlike the other model
# cells - confirm it matches the dataset.
# The head outputs log-probabilities. The training cell uses CrossEntropyLoss,
# which applies log-softmax itself; that is redundant but harmless because
# log-softmax is idempotent.
num_frts = model.fc.in_features
model.fc = nn.Sequential(nn.Linear(num_frts, 6),
                         nn.LogSoftmax(dim=1))
model = model.to(device)

## Swin transformer

In [13]:
def get_model(num_classes: int = 1000, **kwargs):
    """Build a Swin Transformer and warm-start it from a pretrained checkpoint.

    The network is created with patch size 4, window size 7, embedding
    dimension 96, depths (2, 2, 6, 2) and heads (3, 6, 12, 24). Weights are
    read from ``model_weight_path`` (the ``"model"`` entry of the checkpoint)
    and loaded non-strictly, so keys missing on either side are tolerated.

    Args:
        num_classes: Number of output classes for the new model.
        **kwargs: Forwarded unchanged to ``SwinTransformer``.

    Returns:
        The constructed model with the pretrained backbone weights loaded.
    """
    model = SwinTransformer(in_chans=3,
                            patch_size=4,
                            window_size=7,
                            embed_dim=96,
                            depths=(2, 2, 6, 2),
                            num_heads=(3, 6, 12, 24),
                            num_classes=num_classes,
                            **kwargs)
    # TODO: placeholder path - point this at the pretrained checkpoint file.
    model_weight_path = "./xxxxxx.pth"
    weights_dict = torch.load(model_weight_path, map_location=device)["model"]
    # Drop only the classifier weights, whose shape depends on the number of
    # classes the checkpoint was trained with. Deleting every key (as the
    # unconditional loop did) leaves nothing to load, and strict=False hides it.
    # NOTE(review): assumes the classifier parameters are named `head.*`, as in
    # the reference Swin implementation - confirm against ST_tools.
    for k in list(weights_dict.keys()):
        if "head" in k:
            del weights_dict[k]
    model.load_state_dict(weights_dict, strict=False)
    return model
# Build the Swin Transformer via get_model() and move it to `device`.
# TODO: `xxxx` is a redacted placeholder - replace it with the number of target classes.
model = get_model(num_classes=xxxx).to(device)

In [None]:
# Dataset root; expected to contain `train` and `valid` sub-folders.
# TODO: './xxxx/' is a placeholder - point it at the real dataset directory.
data_dir = './xxxx/'
# Join with os.path.join (as the ImageFolder cell does) so the trailing slash
# of `data_dir` does not yield a doubled separator such as './xxxx//train'.
train_dir = os.path.join(data_dir, 'train')
valid_dir = os.path.join(data_dir, 'valid')

In [None]:
def _tensor_and_normalize():
    """Return the tail shared by every pipeline: tensor conversion + normalisation."""
    return [
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]


def _eval_pipeline():
    """Return a fresh deterministic pipeline: resize to 256, centre-crop to 224."""
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        *_tensor_and_normalize(),
    ])


# Per-split preprocessing. Only 'train' applies random augmentation (rotation,
# random resized crop, horizontal flip); 'valid' and 'test' are identical.
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomRotation(45),
        transforms.Resize(256),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        *_tensor_and_normalize(),
    ]),
    'valid': _eval_pipeline(),
    'test': _eval_pipeline(),
}

In [None]:
# Mini-batch size used for both the training and validation DataLoaders.
batch_size = 64

In [None]:
# One ImageFolder per split, read from <data_dir>/<split> and preprocessed with
# the transform pipeline of the same name.
image_datasets = {
    split: datasets.ImageFolder(os.path.join(data_dir, split),
                                data_transforms[split])
    for split in ('train', 'valid')
}
# Wrap each dataset in a shuffling DataLoader that yields `batch_size` samples.
dataloaders = {
    split: torch.utils.data.DataLoader(split_dataset,
                                       batch_size=batch_size,
                                       shuffle=True)
    for split, split_dataset in image_datasets.items()
}

In [None]:
# Checkpoint path; train_model() uses it as its default `filename` and saves
# the best-validation state there.
# TODO: 'xxxx.pth' is a placeholder name.
filename='xxxx.pth'

In [None]:
# Classification loss applied to the model outputs and integer class labels.
criterion = nn.CrossEntropyLoss()
# Hand the optimizer only the parameters that still require gradients, so any
# frozen layers are left untouched.
pg = list(filter(lambda param: param.requires_grad, model.parameters()))
# TODO: `xxxxx` is a redacted placeholder for the learning rate.
optimizer = optim.AdamW(pg, lr=xxxxx, weight_decay=0.05)

In [None]:
# Per-epoch metric logs. train_model() appends to these module-level lists, so
# re-run this cell before training again to start from empty histories.
val_acc_history, train_acc_history = [], []
train_losses, valid_losses = [], []

In [None]:
def train_model(model,
                dataloaders,
                criterion,
                optimizer,
                num_epochs=100,
                filename=filename):
    """Train `model`, validate after every epoch and keep the best weights.

    Each epoch runs a 'train' phase (gradients on, optimizer stepped per batch)
    followed by a 'valid' phase (gradients off). Per-phase loss and accuracy
    are averaged over the dataset size, printed, and appended to the
    module-level lists `train_losses` / `train_acc_history` or
    `valid_losses` / `val_acc_history`. Those lists are not reset here, so
    repeated calls keep extending them.

    Whenever validation accuracy beats the best seen so far, the weights are
    snapshotted in memory and a checkpoint dict with keys 'state_dict',
    'best_acc' and 'optimizer' is written to `filename`.

    Args:
        model: Network to optimise; batches are moved to the global `device`.
        dataloaders: Mapping with 'train' and 'valid' loaders yielding
            (inputs, labels) batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per training batch.
        num_epochs: Number of train+valid epochs to run.
        filename: Destination path for the best checkpoint.

    Returns:
        Tuple ``(model, val_acc_history, train_acc_history, valid_losses,
        train_losses, LRs)``. `model` carries the best validation weights and
        `LRs` holds the first param group's learning rate before training and
        after every epoch. Accuracies are stored as double tensors, losses as
        Python floats.
    """
    start_time = time.time()
    best_acc = 0
    LRs = [optimizer.param_groups[0]["lr"]]
    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(num_epochs):
        print('Epoch{}/{}'.format(epoch, num_epochs - 1))

        for phase in ['train', 'valid']:
            is_train = phase == 'train'
            if is_train:
                model.train()
            else:
                model.eval()

            loader = dataloaders[phase]
            loss_sum = 0.0
            correct = 0
            for inputs, labels in loader:
                inputs, labels = inputs.to(device), labels.to(device)
                optimizer.zero_grad()
                # Autograd is only needed while training.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    preds = torch.max(outputs, 1)[1]
                    if is_train:
                        loss.backward()
                        optimizer.step()
                # The criterion averages over the batch; weight by batch size so
                # the epoch figure is a per-sample mean.
                loss_sum += loss.item() * inputs.size(0)
                correct += (preds == labels).sum()

            num_samples = len(loader.dataset)
            epoch_loss = loss_sum / num_samples
            epoch_acc = correct.double() / num_samples

            elapsed = time.time() - start_time
            print('Time elapsed{:.0f}m{:.0f}s'.format(elapsed // 60,
                                                      elapsed % 60))
            print('{}Loss:{:.4f} Acc:{:.4f}'.format(phase, epoch_loss,
                                                    epoch_acc))

            if is_train:
                train_acc_history.append(epoch_acc)
                train_losses.append(epoch_loss)
            else:
                if epoch_acc > best_acc:
                    # New best validation accuracy: snapshot and checkpoint.
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())
                    torch.save(
                        {
                            'state_dict': model.state_dict(),
                            'best_acc': best_acc,
                            'optimizer': optimizer.state_dict(),
                        }, filename)
                val_acc_history.append(epoch_acc)
                valid_losses.append(epoch_loss)

        current_lr = optimizer.param_groups[0]['lr']
        print('Optimizer learning rate:{:.7f}'.format(current_lr))
        LRs.append(current_lr)

    total_time = time.time() - start_time
    print('Training complete in {:.0f}m {:.0f}s'.format(
        total_time // 60, total_time % 60))
    # Hand back the best-performing weights rather than the last epoch's.
    model.load_state_dict(best_model_wts)
    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs

In [None]:
# Run training and unpack all six results. The target list must be
# parenthesised: split across bare lines, each `name,` is a separate expression
# statement and only `LRs` would be assigned (the whole result tuple).
(model,
 val_acc_history,
 train_acc_history,
 valid_losses,
 train_losses,
 LRs) = train_model(model, dataloaders, criterion, optimizer, num_epochs=100)