In [None]:
import torch
from my_data import prepare_input, prepare_label, CustomDataset
from torch.utils.data import DataLoader
from my_model import ResidualBlock, ResViT_kyu
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import os

In [None]:
# Ensure the checkpoint directory exists; report whether it was newly created.
try:
    os.makedirs('./kyu_models')
except FileExistsError:
    print(f"已經建立資料夾")
else:
    print(f"建立資料夾成功")

In [None]:
# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda:0") if USE_CUDA else torch.device("cpu")
print(device)

In [None]:
# Load the kyu training games. Each CSV line starts with two fields unrelated
# to the moves; everything after the second comma is the comma-separated move
# list, which is all we keep.
with open('./29_Training Dataset/Training Dataset/kyu_train.csv') as csv_file:
    df = csv_file.read().splitlines()
# Skip blank lines: they would otherwise turn into empty "games".
games = [line.split(',', 2)[-1] for line in df if line.strip()]

In [None]:
# Number of previous positions whose channels are stacked onto the current
# input (0 = current position only). Each previous position contributes 2
# channels, so the model's input width is 5 + 2 * pre_data_num.
pre_data_num = 0
model = ResViT_kyu(ResidualBlock=ResidualBlock, input_dim=5 + 2 * pre_data_num).to(device)

In [None]:
# Training hyper-parameters: number of passes over train_games and batch size.
epochs, bs = 2, 512
# Targets are float one-hot vectors, so CrossEntropyLoss is used in its
# "class probabilities" mode (NOTE(review): requires torch >= 1.10).
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
# TensorBoard writer; every training/validation scalar below is logged here.
tb_log_dir = './logs/log_resvit_kyu'
writer = SummaryWriter(log_dir=tb_log_dir)

In [None]:
# Hold out 0.1% of games for validation. The seed is fixed so the split (and
# therefore the validation curves and checkpoint comparisons) is reproducible
# across runs and kernel restarts.
train_games, val_games = train_test_split(games, test_size=0.001, random_state=42)

In [None]:
# Number of games in the training and validation splits.
(len(train_games), len(val_games))

In [None]:
# Encode every validation game into model inputs/targets once, up front.
# Each sample is the current position from `prepare_input`, followed by the
# first 2 channels of each of the previous `pre_data_num` inputs (most recent
# first), zero-padded near the start of a game. This is the same layout the
# training loop builds, giving 5 + 2 * pre_data_num channels as the model expects.
# NOTE(review): assumes prepare_input returns a (C, 19, 19) tensor with C >= 2.
val_x_pre_eight_list = []
val_y_list = []
for val_game in val_games:
    val_move_list = val_game.split(',')
    # Raw (un-stacked) inputs of *this* game only: history must never leak
    # across game boundaries (previously this list spanned all games).
    val_x_list = []

    for idx, val_move in enumerate(val_move_list):
        val_x = prepare_input(val_move_list[:idx])
        val_x_list.append(val_x)

        # Target: one-hot over the 19*19 board positions, as float for the loss.
        val_y = torch.nn.functional.one_hot(prepare_label(val_move), 19 * 19).to(torch.float32)
        val_y_list.append(val_y)

        # Up to pre_data_num previous inputs (2 channels each), then zero padding
        # so every sample has the same channel count.
        history = [val_x_list[idx - 1 - i][:2] for i in range(min(idx, pre_data_num))]
        padding = [torch.zeros((2, 19, 19))] * (pre_data_num - len(history))
        val_x_pre_eight_list.append(torch.cat([val_x, *history, *padding], dim=0))

val_x_pre_eight_stack = torch.stack(val_x_pre_eight_list).to(device)
val_y_stack = torch.stack(val_y_list).to(device)

val_dataset = CustomDataset(val_x_pre_eight_stack, val_y_stack)
data_loader_val = DataLoader(dataset=val_dataset, batch_size=bs, shuffle=False)

In [None]:
# Shape of one encoded validation input; the channel count should equal the
# model's input_dim.
val_x_pre_eight_list[0].size()

In [None]:
# Number of validation samples (one per move across all validation games).
len(data_loader_val.dataset)

In [None]:
# Train in chunks: encode GAMES_PER_CHUNK games into memory, run one pass over
# that chunk, then log metrics. `total_games` counts trained chunks, so a value
# of 1 means the model has seen about 100 games.
GAMES_PER_CHUNK = 100
CHUNKS_PER_CHECKPOINT = 100  # save + validate every 100 chunks (~10,000 games)


def _encode_game(move_list):
    """Encode one game into per-move lists of (input, one-hot target) tensors.

    The input for move `idx` is `prepare_input(move_list[:idx])` followed by the
    first 2 channels of each of the previous `pre_data_num` inputs (most recent
    first), zero-padded near the start of the game.
    NOTE(review): assumes prepare_input returns a (C, 19, 19) tensor with C >= 2.
    """
    inputs, targets, raw_inputs = [], [], []
    for idx, move in enumerate(move_list):
        x = prepare_input(move_list[:idx])
        raw_inputs.append(x)
        history = [raw_inputs[idx - 1 - i][:2] for i in range(min(idx, pre_data_num))]
        padding = [torch.zeros((2, 19, 19))] * (pre_data_num - len(history))
        inputs.append(torch.cat([x, *history, *padding], dim=0))
        # One-hot over the 19*19 board positions, as float for the loss.
        targets.append(torch.nn.functional.one_hot(prepare_label(move), 19 * 19).to(torch.float32))
    return inputs, targets


def _topk_hits(logits, one_hot_targets, k=5):
    """Return (number of top-1 hits, number of top-k hits) in a batch."""
    target_idx = one_hot_targets.argmax(dim=1, keepdim=True)        # (B, 1)
    hits = torch.topk(logits, k=k, dim=1).indices.eq(target_idx)    # (B, k)
    return hits[:, 0].sum().item(), hits.any(dim=1).sum().item()


def _evaluate():
    """Return (mean loss, top-1 accuracy, top-5 accuracy) over the validation set."""
    model.eval()
    total_loss, top1, top5 = 0.0, 0, 0
    try:
        with torch.no_grad():
            for val_x, val_y in tqdm(data_loader_val):
                val_outputs = model(val_x)
                total_loss += loss_fn(val_outputs, val_y).item() * val_x.shape[0]
                hits1, hits5 = _topk_hits(val_outputs, val_y)
                top1 += hits1
                top5 += hits5
    finally:
        model.train()  # restore training mode even if evaluation is interrupted
    n_val = len(val_dataset)
    return total_loss / n_val, top1 / n_val, top5 / n_val


total_games = 0
model.train()  # in case a previous run of this cell was interrupted mid-evaluation
last_game_num = len(train_games) - 1
for epoch in range(epochs):
    train_x_pre_eight_list = []
    train_y_list = []
    for game_num, game in enumerate(train_games):
        game_inputs, game_targets = _encode_game(game.split(','))
        train_x_pre_eight_list.extend(game_inputs)
        train_y_list.extend(game_targets)

        # Train once a full chunk is buffered. Also flush the final partial chunk;
        # otherwise the last len(train_games) % GAMES_PER_CHUNK games of every
        # epoch would be dropped.
        if (game_num + 1) % GAMES_PER_CHUNK != 0 and game_num != last_game_num:
            continue

        train_dataset = CustomDataset(
            torch.stack(train_x_pre_eight_list).to(device),
            torch.stack(train_y_list).to(device),
        )
        data_loader_train = DataLoader(dataset=train_dataset, batch_size=bs, shuffle=True)
        train_x_pre_eight_list = []
        train_y_list = []

        total_loss_game, top1_hits, top5_hits = 0.0, 0, 0
        for x, y in tqdm(data_loader_train):
            outputs = model(x)
            loss = loss_fn(outputs, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss_game += loss.item() * x.shape[0]
            hits1, hits5 = _topk_hits(outputs.detach(), y)
            top1_hits += hits1
            top5_hits += hits5

        total_games += 1
        # Sample-weighted metrics, consistent with the loss (a smaller last batch
        # no longer gets equal weight).
        n_train = len(train_dataset)
        writer.add_scalar('accuracy/top1', top1_hits / n_train, total_games)
        writer.add_scalar('accuracy/top5', top5_hits / n_train, total_games)
        writer.add_scalar('loss/train_games', total_loss_game / n_train, total_games)

        if total_games % CHUNKS_PER_CHECKPOINT == 0:
            # Saves the whole pickled module (not just state_dict); loading it
            # requires the model class to be importable.
            torch.save(model, f'./kyu_models/model{total_games}.pth')

        # Evaluate on the validation set after the first chunk and at every checkpoint.
        if total_games % CHUNKS_PER_CHECKPOINT == 0 or total_games == 1:
            avg_loss_val, val_top1, val_top5 = _evaluate()
            writer.add_scalar('loss/val_games', avg_loss_val, total_games)
            writer.add_scalar('accuracy/val_top1', val_top1, total_games)
            writer.add_scalar('accuracy/val_top5', val_top5, total_games)

writer.flush()