In [31]:
# import some needed libs
import torch
import torch.nn as nn
import torch.utils as utils
import os
from tqdm import tqdm
import os
from datetime import datetime
import pprint
import torchvision.transforms as transforms
from PIL import Image
import os
import logging
import numpy as np
import random

In [32]:
def create_logger(log_name, log_path, show_time=False):
    # create a logger
    logger = logging.getLogger(log_name)
    logger.setLevel(logging.INFO)
    logger.propagate = True
    # create a handler
    handler = logging.FileHandler(log_path)
    # create a formatter
    if show_time:
        formatter = logging.Formatter(fmt='%(asctime)s - %(message)s')
    else:
        formatter = logging.Formatter(fmt='%(message)s')
    # assemble them
    logger.addHandler(handler)
    handler.setFormatter(formatter)
    return logger

In [33]:
# set the random seeds
def same_seeds(seed): # 固定随机种子（CPU）
    torch.manual_seed(seed) # 固定随机种子（GPU)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed) # 为当前GPU设置
        torch.cuda.manual_seed_all(seed)  # 为所有GPU设置
    np.random.seed(seed)  # 保证后续使用random函数时，产生固定的随机数
    torch.backends.cudnn.benchmark = False # GPU、网络结构固定，可设置为True
    torch.backends.cudnn.deterministic = True # 固定网络结构


# define the global seed
def all_seed(seed = 6666):
    """
    设置随机种子
    """
    np.random.seed(seed)
    random.seed(seed)
    # CPU
    torch.manual_seed(seed) 
    # GPU
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.cuda.manual_seed(seed) 
    # python 全局
    os.environ['PYTHONHASHSEED'] = str(seed) 
    # cudnn
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.enabled = False
    print(f'Set env random_seed = {seed}')

In [None]:
# data pre-procession
def read_images(imgs_dir_root, data_tag='training', resize_shape=128):
    '''
    ---
    PARAS
    ---
    imgs_path: str, the file storaged the images
    data_tag: str, 'training', 'test', 'validation'

    ---
    RETURN
    --- 
    return a torch array that storages the images data

    '''
    imgs_dir_root = os.path.join(imgs_dir_root, data_tag)
    test_tfm = transforms.Compose([
        transforms.Resize((resize_shape, resize_shape)),
        transforms.ToTensor(),
    ])

    train_tfm = transforms.Compose([
        transforms.Resize((resize_shape, resize_shape)),
        # AutoAugment: Learning Augmentation Strategies from Data "<https://arxiv.org/pdf/1805.09501.pdf>"
        transforms.AutoAugment(transforms.AutoAugmentPolicy.IMAGENET),
        transforms.ToTensor(),
    ])

    images_list, labels_list, nimages_list = None, None, None
    if data_tag == 'training' or data_tag == 'validation':
        nimages_list = os.listdir(imgs_dir_root)
        # need a shuffle to the nimages_list
        labels_list = [int(nimage.split('_')[0]) for nimage in nimages_list]
        images_list = [train_tfm(Image.open(os.path.join(imgs_dir_root, nfile))) 
                      for nfile in nimages_list]

    elif data_tag == 'test':
        nimages_list = os.listdir(imgs_dir_root)
        images_list = [test_tfm(Image.open(os.path.join(imgs_dir_root, nfile))) 
                      for nfile in nimages_list]
        labels_list = None

    return images_list, labels_list, nimages_list

In [35]:
class Food_Dataset(utils.data.Dataset):
    def __init__(self, images_list, labels_list=None):
        super(Food_Dataset, self).__init__()
        self.images = images_list
        self.labels = labels_list

    def __len__(self):
        if self.labels is not None:
            return (len(self.images) == len(self.labels)) * len(self.images)
        return len(self.images)
    
    def __getitem__(self, index):
        if self.labels is not None:
            return self.images[index], self.labels[index]
        return self.images[index]

In [None]:
class BaseNet(nn.Module):
    def __init__(self):
        super(BaseNet, self).__init__()
        # input [3, 128, 128]
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 64, 3, 1, 1),  # [64, 128, 128]
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 0),      # [64, 64, 64]

            nn.Conv2d(64, 128, 3, 1, 1), # [128, 64, 64]
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 0),      # [128, 32, 32]

            nn.Conv2d(128, 256, 3, 1, 1), # [256, 32, 32]
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 0),      # [256, 16, 16]

            nn.Conv2d(256, 512, 3, 1, 1), # [512, 16, 16]
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 0),       # [512, 8, 8]
            
            nn.Conv2d(512, 512, 3, 1, 1), # [512, 8, 8]
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 0),       # [512, 4, 4]
        )
        self.fc = nn.Sequential(
            nn.Linear(512*4*4, 1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 11)
        )

    def forward(self, x):
        out = self.cnn(x)
        out = out.view(out.size()[0], -1)
        return self.fc(out)

In [37]:
def train(train_dataloader, val_dataloader, config, model, device, logger_train, logger_val, pre_train):
    add_epoch, add_step = pre_train
    criterion = nn.CrossEntropyLoss(ignore_index=-1)
    optimizer = torch.optim.Adam(model.parameters(), 
                                 lr = config['lr'], 
                                 weight_decay=config['weight_decay'])
    # optimizer = torch.optim.SGD(model.parameters(), lr=config['lr'])
    best_acc = 0.0
    early_stop_count = 0
    step = 1

    for epoch in range(config['n_epoch']):

        # strat training
        model.train()
        train_loss_list, train_acc_list = [], []
        for imgs, labels in tqdm(train_dataloader):
            optimizer.zero_grad()
            imgs = imgs.to(device)
            labels = labels.to(device)
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            loss.backward()
            # 稳定训练的技巧
            if config['clip_flag']:
                grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
            
            optimizer.step()
            _, train_pred = torch.max(outputs, dim=1) # find the biggest value along the line
            train_acc_list.append((train_pred.detach() == labels.detach()).sum().item())
            train_loss_list.append(loss.item())
            if step % 10 == 0:
                logger_train.info(f'step {add_step+step:03d} train loss {loss.item():3.6f} train acc {((train_pred.detach() == labels.detach()).sum().item()) / len(imgs):3.6f}')
            step += 1
        
        # record the training info
        train_loss = sum(train_loss_list) / len(train_loss_list)
        train_acc = sum(train_acc_list) / (config['batch_size'] * len(train_acc_list))
        logger_train.info(f"EPOCH [{add_epoch+epoch+1:03d}|{add_epoch+config['n_epoch']:03d}] TRAIN LOSS {train_loss:3.6f} TRAIN ACC {train_acc:3.6f}")
        print(f"EPOCH [{add_epoch+epoch+1:03d}|{add_epoch+config['n_epoch']:03d}] TRAIN LOSS {train_loss:3.6f} TRAIN ACC {train_acc:3.6f}")


        # start valiation
        model.eval()
        val_acc_list, val_loss_list = [], []
        with torch.no_grad():
            for imgs, labels in tqdm(val_dataloader):
                imgs = imgs.to(device)
                labels = labels.to(device)
                outputs = model(imgs)
                loss = criterion(outputs, labels)
                _, val_pred = torch.max(outputs, dim = 1)
                val_acc_list.append((val_pred.detach() == labels.detach()).sum().item())
                val_loss_list.append(loss.item())
        
        # record the val info
        val_loss = sum(val_loss_list) / len(val_loss_list)
        val_acc = sum(val_acc_list) / (config['batch_size'] * len(val_acc_list))
        logger_val.info(f"EPOCH [{add_epoch+epoch+1:03d}|{add_epoch+config['n_epoch']:03d}] VAL LOSS {val_loss:3.6f} VAL ACC {val_acc:3.6f}")
        print(f"EPOCH [{add_epoch+epoch+1:03d}|{add_epoch+config['n_epoch']:03d}] VAL LOSS {val_loss:3.6f} VAL ACC {val_acc:3.6f}")

        # save the best model
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), f"{config['checkpoints_dir']}/best_checkpoint.pt")
            print(f'saving the best model with val acc: {best_acc:3.6f}')
            early_stop_count = 0
        else:
            early_stop_count += 1

        # early stop
        if early_stop_count >= config['early_stop']:
            print('\nModel is not improving, so halt the training session.')
            
            
    # save the latest model for belowing training     
    torch.save(model.state_dict(), f"{config['checkpoints_dir']}/latest_checkpoint.pt")
    print(f"saving the latest model...")

In [38]:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    
config = {
    # =============================================
    'data_dir':'./data/food11',
    'checkpoints_dir':f'./checkpoints/{timestamp}',
    'loggers_dir':f'./loggers/{timestamp}',
    # =============================================
    'seed': 6666,
    'n_epoch': 100,
    'batch_size': 128,
    'lr':0.0005,
    'weight_decay':1e-3,
    'early_stop': 25,
    'num_workers': 0,
    'clip_flag': True,
    # =============================================
}
all_seed(config['seed'])
if not os.path.exists(config['checkpoints_dir']):
    os.makedirs(config['checkpoints_dir'], mode=0o755)
if not os.path.exists(config['loggers_dir']):
    os.makedirs(config['loggers_dir'], mode=0o755)
train_images_list, train_labels_list, _ = read_images(config['data_dir'], 'training', resize_shape=128)
val_images_list, val_labels_list, _ = read_images(config['data_dir'], 'validation', resize_shape=128)
train_set = Food_Dataset(train_images_list, train_labels_list)
val_set = Food_Dataset(val_images_list, val_labels_list)
del train_images_list, train_labels_list, val_images_list, val_labels_list
train_dataloader = utils.data.DataLoader(train_set, batch_size=config['batch_size'], 
                                        shuffle=True, num_workers=config['num_workers'], 
                                        pin_memory=True, drop_last=True)
val_dataloader = utils.data.DataLoader(val_set, batch_size=config['batch_size'], 
                                        shuffle=True, num_workers=config['num_workers'], 
                                        pin_memory=True, drop_last=True)
print(f"DEVICE:{device}")
my_model = BaseNet().to(device)
# my_model = ResNet().to(device)
# from torchvision.models import resnet50
# my_model = resnet50(pretrained=True).to(device)
# time_tag = 20250827_174759
# my_model.load_state_dict(torch.load(f"checkpoints/20250827_174759/latest_checkpoint.pt", map_location=device))
    
train_logger = create_logger('train_logger', f"{config['loggers_dir']}/train.log", show_time=False)
val_logger = create_logger('val_logger', f"{config['loggers_dir']}/val.log", show_time=False)
note_logger = create_logger('note_logger', f"{config['loggers_dir']}/note.log", show_time=True)
note_logger.info(f'DEVICE:\n{device}')
note_logger.info(f'MODEL INFO:\n{str(my_model)}')
note_logger.info(f'CONFIG INFO:\n{pprint.pformat(config, indent=4)}')
train(train_dataloader, val_dataloader, config, my_model, device, train_logger, val_logger, pre_train=[0, 0])

Set env random_seed = 6666
DEVICE:cuda


100%|██████████| 77/77 [01:17<00:00,  1.01s/it]


EPOCH [001|100] TRAIN LOSS 2.273899 TRAIN ACC 0.194095


100%|██████████| 26/26 [00:02<00:00, 10.04it/s]


EPOCH [001|100] VAL LOSS 2.222982 VAL ACC 0.192909
saving the best model with val acc: 0.192909


100%|██████████| 77/77 [01:17<00:00,  1.01s/it]


EPOCH [002|100] TRAIN LOSS 2.100618 TRAIN ACC 0.265219


100%|██████████| 26/26 [00:02<00:00, 10.01it/s]


EPOCH [002|100] VAL LOSS 2.090106 VAL ACC 0.268630
saving the best model with val acc: 0.268630


100%|██████████| 77/77 [01:17<00:00,  1.01s/it]


EPOCH [003|100] TRAIN LOSS 1.979142 TRAIN ACC 0.319298


100%|██████████| 26/26 [00:02<00:00,  9.92it/s]


EPOCH [003|100] VAL LOSS 2.014100 VAL ACC 0.320913
saving the best model with val acc: 0.320913


100%|██████████| 77/77 [01:18<00:00,  1.02s/it]


EPOCH [004|100] TRAIN LOSS 1.865118 TRAIN ACC 0.350548


100%|██████████| 26/26 [00:02<00:00, 10.05it/s]


EPOCH [004|100] VAL LOSS 2.003217 VAL ACC 0.294171


100%|██████████| 77/77 [01:19<00:00,  1.03s/it]


EPOCH [005|100] TRAIN LOSS 1.760790 TRAIN ACC 0.386567


100%|██████████| 26/26 [00:02<00:00,  9.47it/s]


EPOCH [005|100] VAL LOSS 1.857027 VAL ACC 0.354567
saving the best model with val acc: 0.354567


100%|██████████| 77/77 [01:19<00:00,  1.03s/it]


EPOCH [006|100] TRAIN LOSS 1.658817 TRAIN ACC 0.424513


100%|██████████| 26/26 [00:02<00:00,  9.91it/s]


EPOCH [006|100] VAL LOSS 1.734964 VAL ACC 0.408353
saving the best model with val acc: 0.408353


100%|██████████| 77/77 [01:17<00:00,  1.01s/it]


EPOCH [007|100] TRAIN LOSS 1.566923 TRAIN ACC 0.452821


100%|██████████| 26/26 [00:02<00:00, 10.02it/s]


EPOCH [007|100] VAL LOSS 1.790058 VAL ACC 0.388221


100%|██████████| 77/77 [01:17<00:00,  1.01s/it]


EPOCH [008|100] TRAIN LOSS 1.490121 TRAIN ACC 0.482346


100%|██████████| 26/26 [00:02<00:00, 10.05it/s]


EPOCH [008|100] VAL LOSS 1.801004 VAL ACC 0.378606


100%|██████████| 77/77 [01:17<00:00,  1.01s/it]


EPOCH [009|100] TRAIN LOSS 1.412630 TRAIN ACC 0.508929


100%|██████████| 26/26 [00:02<00:00, 10.08it/s]


EPOCH [009|100] VAL LOSS 1.764232 VAL ACC 0.413462
saving the best model with val acc: 0.413462


100%|██████████| 77/77 [01:18<00:00,  1.01s/it]


EPOCH [010|100] TRAIN LOSS 1.339168 TRAIN ACC 0.534903


100%|██████████| 26/26 [00:02<00:00,  9.96it/s]


EPOCH [010|100] VAL LOSS 1.729841 VAL ACC 0.420373
saving the best model with val acc: 0.420373


100%|██████████| 77/77 [01:18<00:00,  1.01s/it]


EPOCH [011|100] TRAIN LOSS 1.260377 TRAIN ACC 0.561688


100%|██████████| 26/26 [00:02<00:00, 10.00it/s]


EPOCH [011|100] VAL LOSS 2.411820 VAL ACC 0.318810


100%|██████████| 77/77 [01:17<00:00,  1.01s/it]


EPOCH [012|100] TRAIN LOSS 1.178190 TRAIN ACC 0.593344


100%|██████████| 26/26 [00:02<00:00, 10.01it/s]


EPOCH [012|100] VAL LOSS 1.600648 VAL ACC 0.460938
saving the best model with val acc: 0.460938


100%|██████████| 77/77 [01:17<00:00,  1.01s/it]


EPOCH [013|100] TRAIN LOSS 1.084322 TRAIN ACC 0.621652


100%|██████████| 26/26 [00:02<00:00, 10.03it/s]


EPOCH [013|100] VAL LOSS 1.635768 VAL ACC 0.478365
saving the best model with val acc: 0.478365


100%|██████████| 77/77 [01:17<00:00,  1.01s/it]


EPOCH [014|100] TRAIN LOSS 1.004398 TRAIN ACC 0.655438


100%|██████████| 26/26 [00:02<00:00, 10.03it/s]


EPOCH [014|100] VAL LOSS 1.841349 VAL ACC 0.442007


100%|██████████| 77/77 [01:18<00:00,  1.01s/it]


EPOCH [015|100] TRAIN LOSS 0.936128 TRAIN ACC 0.675629


100%|██████████| 26/26 [00:02<00:00, 10.11it/s]


EPOCH [015|100] VAL LOSS 2.238690 VAL ACC 0.390625


100%|██████████| 77/77 [01:19<00:00,  1.03s/it]


EPOCH [016|100] TRAIN LOSS 0.836663 TRAIN ACC 0.713981


100%|██████████| 26/26 [00:02<00:00,  9.81it/s]


EPOCH [016|100] VAL LOSS 1.776462 VAL ACC 0.472957


100%|██████████| 77/77 [01:19<00:00,  1.03s/it]


EPOCH [017|100] TRAIN LOSS 0.739967 TRAIN ACC 0.742188


100%|██████████| 26/26 [00:02<00:00,  9.28it/s]


EPOCH [017|100] VAL LOSS 2.147530 VAL ACC 0.443510


  6%|▋         | 5/77 [00:06<01:28,  1.23s/it]


KeyboardInterrupt: 