In [29]:
%load_ext autoreload
%autoreload 1

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [30]:
import os
from time import sleep

import matplotlib.pyplot as plt
import cv2
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import PIL.Image as Image

from dataset import *
from utilities import *
from resnet import *
from config import *

In [31]:
color_to_gray_map, gray_to_color_map = None, None

In [32]:
np.random.seed(0)
train_imageset_path = '../trainval/DAVIS/ImageSets/2017/train.txt'
val_imageset_path = '../trainval/DAVIS/ImageSets/2017/val.txt'
testd_imageset_path = '../testd/DAVIS/ImageSets/2017/test-dev.txt'
trainval_image_root = '../trainval/DAVIS/JPEGImages/480p/'
trainval_mask_root = '../trainval/DAVIS/Annotations/480p/'
testd_image_root = '../testd/DAVIS/JPEGImages/480p/'
testd_mask_root = '../testd/DAVIS/Annotations/480p/'
models_root = '../models/'

train_list = []
val_list = []
test_list = []

with open(train_imageset_path, 'r') as f:
    for line in f:
        train_list.append(line.strip())
with open(val_imageset_path, 'r') as f:
    for line in f:
        val_list.append(line.strip())
with open(testd_imageset_path, 'r') as f:
    for line in f:
        test_list.append(line.strip())


In [33]:
import logging
# 设置日志
logging.basicConfig(filename='training_log.log',  # 日志文件名
                    filemode='a',  # 文件模式，'a' 表示追加
                    level=logging.INFO,  # 日志级别
                    format='%(asctime)s - %(levelname)s - %(message)s')  # 日志格式
def train(image_root, mask_root, target_list):
    # 初始化存储每个目标的最佳精确度的列表
    all_best_accuracies = []
    logging.info("Start of the program")
    
    for t in range(len(target_list)):
        # if t != 0:
        #     continue
        print(target_list[t])
        logging.info(target_list[t])
        image_path = os.path.join(image_root, target_list[t] + '/00000.jpg')
        mask_path = os.path.join(mask_root, target_list[t] + '/00000.png')
        model_save_path = os.path.join(models_root, target_list[t] + '.pt')
    
        image = cv2.imread(image_path)
        mask = cv2.imread(mask_path)
        PIL_mask = Image.open(mask_path)
        color_to_gray_map, gray_to_color_map = get_map(mask, PIL_mask)
        del PIL_mask
        
        image = cv2.resize(image, Resize, interpolation=cv2.INTER_NEAREST)
        mask = cv2.resize(mask, Resize, interpolation=cv2.INTER_NEAREST)
        mask = convert_to_gray_mask(mask, color_to_gray_map)
        print('type_cnt:', len(color_to_gray_map))
        logging.info(f'type_cnt: {len(color_to_gray_map)}')
    
        model = MyResNet(len(color_to_gray_map)).to(device)
        train_dataset = CustomDataset(image_path, mask_path, image_transform=train_image_transforms, mask_transform=train_mask_transforms, num_samples=augmentation_num)
        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    
        val_dataset = CustomDataset(image_path, mask_path, image_transform=val_image_transforms, mask_transform=val_mask_transforms, num_samples=1)
        val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False)
    
        opt = torch.optim.Adam(model.parameters(), lr=learning_rate)
        sch = torch.optim.lr_scheduler.StepLR(opt, step_size=step_size, gamma=gamma)
        best_accuracy = 0
        best_loss = float('inf')
        patience = 10  # Patience is the number of epochs with no improvement after which training will be stopped.
        patience_counter = 0  # Counter to keep track of the number of epochs with no improvement.
    
        for i in range(train_epoch):
            # print('epoch:', i)
            model.train()
            for image, mask in train_dataloader:
                mask = (mask * 255).long()
                input = torch.cat((image, mask), dim=1).to(device)
                output_mask = torch.argmax(model(input), dim=1)
                output = model(input)
                loss = F.cross_entropy(output, input[:, 3, :, :].long())
                accuracy = torch.sum(output_mask == torch.tensor(input[:, 3, :, :]).to(device)).item() / (224 * 224) / batch_size
                # print(f'Train Loss: {loss.item():.4f} Train Accuracy: {accuracy:.4f}')
    
                opt.zero_grad()
                loss.backward()
                opt.step()
                sch.step()
    
            model.eval()
            for image, mask in val_dataloader:
                
                mask = (mask * 255).long()
                input = torch.cat((image, mask), dim=1).to(device)
                output_mask = torch.argmax(model(input), dim=1)
                output = model(input)
                loss = F.cross_entropy(output, input[:, 3, :, :].long())
                accuracy = torch.sum(output_mask == torch.tensor(input[:, 3, :, :]).to(device)).item() / (224 * 224)
                # print(f'Eval Loss: {loss.item():.4f} Eval Accuracy: {accuracy:.4f}')
                # logging.info(f'Eval Loss: {loss.item():.4f} Eval Accuracy: {accuracy:.4f}')
            if best_loss - loss.item() > 0.001:
                best_accuracy = accuracy
                best_loss = loss.item()
                patience_counter = 0  # Reset the patience counter
            else:
                patience_counter += 1  # Increment the patience counter
            
            if patience_counter >= patience:
                print(f"Early stopping initiated at epoch {i}")
                logging.info(f"Early stopping initiated at epoch {i}")
                if not os.path.exists(models_root):
                    os.makedirs(models_root)
                torch.save(model.state_dict(), model_save_path)
                break  # Break the inner loop and continue with the next target
                
                # mask_np = mask.squeeze(0).cpu().numpy().transpose(1, 2, 0)
                # output_mask_np = output_mask.cpu().numpy().transpose(1, 2, 0)
                # image_np = image.squeeze(0).numpy().transpose(1, 2, 0)
                # print_images([image_np, mask_np, np.where(output_mask_np == mask_np, 1, 0)])
    
        # Print the best evaluation metrics after all epochs
        # print(f'Best Eval Loss: {best_loss:.4f} Best Eval Accuracy: {best_accuracy:.4f}')
        logging.info(f'Best Eval Loss: {best_loss:.4f} Best Eval Accuracy: {best_accuracy:.4f}')
        all_best_accuracies.append(best_accuracy)
        if not os.path.exists(models_root):
            os.makedirs(models_root)
        torch.save(model.state_dict(), model_save_path)
    # 计算所有目标的最佳精确度的平均值
    average_best_accuracy = sum(all_best_accuracies) / len(all_best_accuracies)
    logging.info(f'Average Best Accuracy: {average_best_accuracy:.4f}')  # 添加这行来记录平均最佳精确度
    # return average_best_accuracy


In [34]:
# import optuna

# def objective(trial):
#     lr = trial.suggest_loguniform('lr', 1e-4, 5e-4)
#     batch_size = trial.suggest_categorical('batch_size', [32, 48])
#     step_size = trial.suggest_int('step_size', 10, 30)
#     gamma = trial.suggest_uniform('gamma', 0.6, 0.9)
#     train_epoch = trial.suggest_int('train_epoch', 80, 120)

#     train_list = ['boxing-fisheye']
#     best_accuracy = train(trainval_image_root, trainval_mask_root, train_list, models_root, lr, batch_size, step_size, gamma, train_epoch)
#     return best_accuracy
# study = optuna.create_study(direction='maximize')
# study.optimize(objective, n_trials=10)
# # Output best trial info
# logging.info("Optimization finished. Best trial:")
# trial = study.best_trial
# logging.info(f"Value (Best Loss): {trial.value}")
# for key, value in trial.params.items():
#     logging.info(f"{key}: {value}")
# # Ensure all logs are written to the file
# logging.shutdown()


In [35]:
train(trainval_image_root, trainval_mask_root, train_list)
train(trainval_image_root, trainval_mask_root, val_list)
train(testd_image_root, testd_mask_root, test_list)

bear
type_cnt: 2


  accuracy = torch.sum(output_mask == torch.tensor(input[:, 3, :, :]).to(device)).item() / (224 * 224) / batch_size
  accuracy = torch.sum(output_mask == torch.tensor(input[:, 3, :, :]).to(device)).item() / (224 * 224)


Early stopping initiated at epoch 25
bmx-bumps
type_cnt: 3
Early stopping initiated at epoch 31
boat
type_cnt: 2
Early stopping initiated at epoch 27
boxing-fisheye
type_cnt: 4
Early stopping initiated at epoch 39
breakdance-flare
type_cnt: 2
Early stopping initiated at epoch 28
bus
type_cnt: 2
Early stopping initiated at epoch 29
car-turn
type_cnt: 2
Early stopping initiated at epoch 26
cat-girl
type_cnt: 3
Early stopping initiated at epoch 29
classic-car
type_cnt: 4
Early stopping initiated at epoch 46
color-run
type_cnt: 4
Early stopping initiated at epoch 41
crossing
type_cnt: 4
Early stopping initiated at epoch 37
dance-jump
type_cnt: 2
Early stopping initiated at epoch 32
dancing
type_cnt: 4
Early stopping initiated at epoch 44
disc-jockey
type_cnt: 4
Early stopping initiated at epoch 40
dog-agility
type_cnt: 2
Early stopping initiated at epoch 33
dog-gooses
type_cnt: 6
Early stopping initiated at epoch 34
dogs-scale
type_cnt: 5
Early stopping initiated at epoch 53
drift-turn
typ

In [36]:
import subprocess
import time
import logging

# 你的其他代码，例如模型训练等
# ...

# 训练完成后，等待5秒
time.sleep(30)

# 然后调用Windows命令让系统睡眠
subprocess.run(["rundll32.exe", "powrprof.dll,SetSuspendState", "0,1,0"], check=True)


CompletedProcess(args=['rundll32.exe', 'powrprof.dll,SetSuspendState', '0,1,0'], returncode=0)