# Train Model
## Written By KYLiN

In [21]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader , random_split
from torch.utils.data.sampler import WeightedRandomSampler
from torchvision.models import mobilenet_v3_large , MobileNet_V3_Large_Weights
from sklearn.metrics import f1_score
import pandas as pd 

# speed up 
from torch.cuda.amp import GradScaler , autocast

# other sampler
from torchsampler import ImbalancedDatasetSampler


from rich import print
from tqdm import tqdm
import os
from time import time

In [22]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

In [23]:
# 训练数据的 transforms
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(size=224, scale=(0.8, 1.0)),
    transforms.RandomAffine(degrees=0,translate=(0.05,0.05)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(degrees=15),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 测试数据的 transforms
transform_test = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [24]:
dataset_path = "./database"
dataset = ImageFolder(dataset_path)

dataset_size = len(dataset)
train_size = int(0.99999998 * dataset_size)
val_size = dataset_size - train_size

print(f"dataset size: {dataset_size}, train size: {train_size}, val size: {val_size}")

In [25]:
dataset.classes

['-10_1', '19_35', '2_6', '36_100', '7_18']

In [26]:
classes_length = len(dataset.classes)
classes_length

5

In [27]:
train_dataset , val_dataset = random_split(dataset , [train_size , val_size])

train_dataset.dataset.transform = transform_train
val_dataset.dataset.transform = transform_test


In [28]:
TRAIN_BATCH_SIZE = 32
TEST_BATCH_SIZE = 32

In [29]:
train_loader = DataLoader(train_dataset,
                        #   sampler=WeightedRandomSampler([boo_weight , hot_weight], num_samples=len(dataset) , replacement=True), # include shuffle=True
                          batch_size=TRAIN_BATCH_SIZE,
                          shuffle=True,
                          )

test_loader = DataLoader(val_dataset,
                         batch_size=TEST_BATCH_SIZE,
                         shuffle=True,
                         )

In [30]:
# MODEL_PATH = os.path.join("./model" , "mobileNet_v3_test_v6.pth")
to_model_path = lambda epoch_num , type_:os.path.join("./model" , f"mobileNet_v3_v8_{epoch_num}_{type_}.pth")

model = mobilenet_v3_large(weights=MobileNet_V3_Large_Weights.IMAGENET1K_V2)

num_features = model.classifier[-1].in_features
# output only two class 
model.classifier[-1] = nn.Linear(num_features , 5)

model = model.to(device=device)
# print(model)

In [31]:
# loss function and optimizer
criterion = nn.CrossEntropyLoss() # weight=sample_weight.to(device)

optimizer = optim.Adam(model.parameters() , lr=3e-4 ,  weight_decay=0.0001)

In [32]:

TRAIN_EPOCH = 30
train_data_record_df = pd.DataFrame({
    "epoch":[] , 
    "time(mins)": [] , 
    "loss" : [] , 
    "train_acc" :[],
    "train_f1":[] , 
    "test_acc": [] , 
    "test_f1":[] , 
    "update":[],
    "File Name" : [],
})
train_data_record_df

Unnamed: 0,epoch,time(mins),loss,train_acc,train_f1,test_acc,test_f1,update,File Name


In [33]:
scaler = GradScaler()
# old_test_acc = -1…
old_test_acc , old_f1_mark = -1 , -1
for epoch in range(TRAIN_EPOCH):
    print(f"Epoch: [{epoch}]")
    start_time = time()
    
    
    train_acc , test_acc = 0 , 0 
    # training
    model.train()
    
    train_pred_list , train_labels_list = [] , []
    test_pred_list , test_labels_list = [] , []
    
    
    with tqdm(train_loader , unit="batch" , desc="Training...") as t_epoch:
        for inputs , labels in t_epoch:
            
            # in cuda
            torch.cuda.empty_cache()
            inputs , labels_gpu = inputs.to(device) , labels.to(device)
            # zero the parameter gradients
            optimizer.zero_grad()
            
            # forward + backward + optimize
            with autocast():
                model_outputs = model(inputs)
                
                loss = criterion(model_outputs.float() , labels_gpu)
                
            # use scaler update  
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
                
            
            # in cpu
            model_outputs = model_outputs.cpu()
            # dim is one , get array 
            train_pred = torch.max(model_outputs , 1).indices
            # how many is same 
            train_acc += int(torch.sum(train_pred == labels))
            # for f1
            train_labels_list.extend(labels.tolist())
            train_pred_list.extend(train_pred.tolist())
        
        # get epoch train acc 
        ep_train_acc = train_acc / train_size
        
    train_f1 = f1_score(train_labels_list , train_pred_list ,average="micro")
    
    # lock model 
    model.eval()
    with torch.no_grad():
        
        # validation
        with tqdm(test_loader , unit="batch" , desc="Testing...") as test_epoch:
            for inputs , labels in test_epoch:
                # in cuda
                torch.cuda.empty_cache()
                inputs , labels_gpu = inputs.to(device) , labels.to(device)
                test_prob = model(inputs)
                
                # in cpu
                test_prob = test_prob.cpu()
                test_pred = torch.max(test_prob , 1).indices
                test_acc += int(torch.sum(test_pred == labels))
                
                test_labels_list.extend(labels.tolist())
                test_pred_list.extend(test_pred.tolist())  
                
        ep_test_acc = test_acc / val_size
        # ep_test_acc = test_acc / len(dataset)
    test_f1 = f1_score(test_labels_list , test_pred_list,average="micro")         
    
    end_time = time()
    duration = (end_time - start_time) / 60
    print(f"Time: {duration}, Loss: {loss:.2f}\nTrain_acc: {ep_train_acc*100 :.2f}, Test_acc: {ep_test_acc*100 :.2f}")
    print(f"Train f1:{train_f1:.3f}, Test f1: {test_f1:.3f}")
    
    to_df_data = {
        "epoch":epoch , 
        "time(mins)": duration , 
        "loss" : loss.cpu().detach().numpy() , 
        "train_acc" : ep_train_acc*100.0,
        "train_f1":train_f1 , 
        "test_acc": ep_test_acc*100 , 
        "test_f1":test_f1 , 
        "update":False,
        "File Name" : "",
    }
    
    check_condition = (ep_train_acc > old_test_acc, train_f1 > old_f1_mark)
    
    if check_condition[0] or check_condition[1]:
        
        if check_condition[0]:
            type_str = "ACC"
            old_test_acc = ep_train_acc
            
        if check_condition[1]:
            type_str = "F1"
            old_f1_mark = train_f1
        
        if check_condition[0] and check_condition[1]:
            type_str = "BOTH"
        
        model_path = to_model_path(epoch , type_str)
        torch.save(model.state_dict() , model_path)
        old_f1_mark = train_f1
        print(f"update new model, new acc: {ep_train_acc*100 :.2f}, new f1: {train_f1:.3f} , save in {model_path}")
        to_df_data["update"] = True
        to_df_data["File Name"] = model_path
        
    
    train_data_record_df.loc[len(train_data_record_df)] = to_df_data 
    train_data_record_df.to_csv("./table/mobileNet_v3_v8_records.csv")
    
    

    
    

Training...:   0%|          | 0/1738 [00:00<?, ?batch/s]

Training...: 100%|██████████| 1738/1738 [24:39<00:00,  1.17batch/s]
Testing...: 100%|██████████| 1/1 [00:00<00:00,  5.92batch/s]


Training...: 100%|██████████| 1738/1738 [25:45<00:00,  1.12batch/s]
Testing...: 100%|██████████| 1/1 [00:00<00:00, 27.78batch/s]


Training...: 100%|██████████| 1738/1738 [23:52<00:00,  1.21batch/s]
Testing...: 100%|██████████| 1/1 [00:00<00:00, 29.68batch/s]


Training...: 100%|██████████| 1738/1738 [24:30<00:00,  1.18batch/s]
Testing...: 100%|██████████| 1/1 [00:00<00:00, 24.28batch/s]


Training...: 100%|██████████| 1738/1738 [23:56<00:00,  1.21batch/s]
Testing...: 100%|██████████| 1/1 [00:00<00:00, 31.25batch/s]


Training...: 100%|██████████| 1738/1738 [23:18<00:00,  1.24batch/s]
Testing...: 100%|██████████| 1/1 [00:00<00:00, 24.39batch/s]


Training...: 100%|██████████| 1738/1738 [22:38<00:00,  1.28batch/s]
Testing...: 100%|██████████| 1/1 [00:00<00:00, 27.18batch/s]


Training...: 100%|██████████| 1738/1738 [22:39<00:00,  1.28batch/s]
Testing...: 100%|██████████| 1/1 [00:00<00:00, 27.59batch/s]


Training...: 100%|██████████| 1738/1738 [22:32<00:00,  1.29batch/s]
Testing...: 100%|██████████| 1/1 [00:00<00:00, 32.26batch/s]


Training...: 100%|██████████| 1738/1738 [22:35<00:00,  1.28batch/s]
Testing...: 100%|██████████| 1/1 [00:00<00:00, 33.34batch/s]


Training...: 100%|██████████| 1738/1738 [22:34<00:00,  1.28batch/s]
Testing...: 100%|██████████| 1/1 [00:00<00:00, 26.32batch/s]


Training...: 100%|██████████| 1738/1738 [22:30<00:00,  1.29batch/s]
Testing...: 100%|██████████| 1/1 [00:00<00:00, 32.26batch/s]


Training...: 100%|██████████| 1738/1738 [22:48<00:00,  1.27batch/s]
Testing...: 100%|██████████| 1/1 [00:00<00:00, 29.41batch/s]


Training...: 100%|██████████| 1738/1738 [22:35<00:00,  1.28batch/s]
Testing...: 100%|██████████| 1/1 [00:00<00:00, 29.41batch/s]


Training...: 100%|██████████| 1738/1738 [22:39<00:00,  1.28batch/s]
Testing...: 100%|██████████| 1/1 [00:00<00:00, 27.78batch/s]


Training...: 100%|██████████| 1738/1738 [22:46<00:00,  1.27batch/s]
Testing...: 100%|██████████| 1/1 [00:00<00:00, 31.25batch/s]


Training...: 100%|██████████| 1738/1738 [22:37<00:00,  1.28batch/s]
Testing...: 100%|██████████| 1/1 [00:00<00:00, 31.25batch/s]


Training...: 100%|██████████| 1738/1738 [22:39<00:00,  1.28batch/s]
Testing...: 100%|██████████| 1/1 [00:00<00:00, 32.26batch/s]


Training...: 100%|██████████| 1738/1738 [22:49<00:00,  1.27batch/s]
Testing...: 100%|██████████| 1/1 [00:00<00:00, 30.30batch/s]


Training...: 100%|██████████| 1738/1738 [22:40<00:00,  1.28batch/s]
Testing...: 100%|██████████| 1/1 [00:00<00:00, 32.26batch/s]


Training...: 100%|██████████| 1738/1738 [22:56<00:00,  1.26batch/s]
Testing...: 100%|██████████| 1/1 [00:00<00:00, 27.78batch/s]


Training...: 100%|██████████| 1738/1738 [22:44<00:00,  1.27batch/s]
Testing...: 100%|██████████| 1/1 [00:00<00:00, 30.30batch/s]


Training...: 100%|██████████| 1738/1738 [22:47<00:00,  1.27batch/s]
Testing...: 100%|██████████| 1/1 [00:00<00:00, 28.57batch/s]


Training...: 100%|██████████| 1738/1738 [22:42<00:00,  1.28batch/s]
Testing...: 100%|██████████| 1/1 [00:00<00:00, 32.26batch/s]


Training...: 100%|██████████| 1738/1738 [22:42<00:00,  1.28batch/s]
Testing...: 100%|██████████| 1/1 [00:00<00:00, 28.57batch/s]


Training...: 100%|██████████| 1738/1738 [22:48<00:00,  1.27batch/s]
Testing...: 100%|██████████| 1/1 [00:00<00:00, 29.41batch/s]


Training...: 100%|██████████| 1738/1738 [22:45<00:00,  1.27batch/s]
Testing...: 100%|██████████| 1/1 [00:00<00:00, 33.33batch/s]


Training...: 100%|██████████| 1738/1738 [22:42<00:00,  1.28batch/s]
Testing...: 100%|██████████| 1/1 [00:00<00:00, 33.33batch/s]


Training...: 100%|██████████| 1738/1738 [22:58<00:00,  1.26batch/s]
Testing...: 100%|██████████| 1/1 [00:00<00:00, 29.41batch/s]


Training...: 100%|██████████| 1738/1738 [23:17<00:00,  1.24batch/s]
Testing...: 100%|██████████| 1/1 [00:00<00:00, 27.78batch/s]


In [3]:
import pandas as pd 

In [4]:
train_table_result = pd.read_csv("./table/mobileNet_v3_v8_records.csv")
train_table_result

Unnamed: 0.1,Unnamed: 0,epoch,time(mins),loss,train_acc,train_f1,test_acc,test_f1,update,File Name
0,0,0,24.668357,1.342858,32.183267,0.321833,0.0,0.0,True,./model\mobileNet_v3_v8_0_BOTH.pth
1,1,1,25.756044,1.336178,37.331582,0.373316,100.0,1.0,True,./model\mobileNet_v3_v8_1_BOTH.pth
2,2,2,23.875603,1.281872,41.828713,0.418287,0.0,0.0,True,./model\mobileNet_v3_v8_2_BOTH.pth
3,3,3,24.510022,1.101552,47.218075,0.472181,0.0,0.0,True,./model\mobileNet_v3_v8_3_BOTH.pth
4,4,4,23.938384,1.505842,54.773255,0.547733,0.0,0.0,True,./model\mobileNet_v3_v8_4_BOTH.pth
5,5,5,23.316174,0.978088,62.715188,0.627152,100.0,1.0,True,./model\mobileNet_v3_v8_5_BOTH.pth
6,6,6,22.643145,0.787612,70.725477,0.707255,0.0,0.0,True,./model\mobileNet_v3_v8_6_BOTH.pth
7,7,7,22.663777,0.867645,77.723013,0.77723,100.0,1.0,True,./model\mobileNet_v3_v8_7_BOTH.pth
8,8,8,22.538653,0.995355,82.455793,0.824558,100.0,1.0,True,./model\mobileNet_v3_v8_8_BOTH.pth
9,9,9,22.59942,0.476034,85.45448,0.854545,0.0,0.0,True,./model\mobileNet_v3_v8_9_BOTH.pth
