In [None]:
import os

import torch
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from sklearn.model_selection import train_test_split
from tqdm import tqdm

from my_data import prepare_input, CustomDataset
from my_model import ResidualBlock, ResViT_PS

In [None]:
# Seed torch's RNGs (CPU and CUDA) so weight initialisation and DataLoader shuffling are reproducible.
SEED = 64
torch.manual_seed(SEED);

# Use the first GPU when available, otherwise fall back to the CPU.
USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda:0" if USE_CUDA else "cpu")
print(device)

In [None]:
# Each CSV line is split at most twice: the second field is the integer play-style label and the
# last field keeps the whole comma-separated move sequence intact (first field presumably an id).
with open('./29_Training Dataset/Training Dataset/play_style_train.csv') as csv_file:
    df = csv_file.read().splitlines()
rows = [line.split(',', 2) for line in df]
games = [row[-1] for row in rows]
game_styles = [int(row[-2]) for row in rows]

In [None]:
# Every game needs exactly one label, and labels must be 1..3 for the one-hot encoding below.
assert len(game_styles) == len(games)
assert set(game_styles) <= {1, 2, 3}, f"unexpected play-style labels: {sorted(set(game_styles))}"
len(game_styles)

In [None]:
# Labels are 1..3 in the CSV: shift to 0-based class indices, one-hot encode into 3 classes and cast
# to float so the rows can be used as probability targets by CrossEntropyLoss. `.to(...)` avoids the
# UserWarning that `torch.tensor(existing_tensor)` raises when copy-constructing from a tensor.
y_list = torch.nn.functional.one_hot(torch.tensor(game_styles) - 1, num_classes=3).to(torch.float32)

In [None]:
# One row per game, one column per play-style class.
assert y_list.shape == (len(game_styles), 3)
y_list.shape

In [None]:
# Number of games read from the CSV (one per line).
n_games = len(games)
print(f"Total Games: {n_games}")

In [None]:
# Number of earlier board positions stacked behind the current one as extra input planes.
pre_data_num = 7
# Input channels: 5 planes for the current position (assumed to be what prepare_input
# returns — TODO confirm) plus 2 planes for each history position.
model = ResViT_PS(
    input_dim=5 + 2 * pre_data_num,
    ResidualBlock=ResidualBlock,
).to(device)

In [None]:
# Training hyper-parameters.
epochs = 8
bs = 64  # mini-batch size for the training DataLoader

loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)
# StepLR with step_size=1: every scheduler.step() multiplies the learning rate by gamma=0.9.
scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)

In [None]:
# TensorBoard event files for this run are written under this directory.
tb_log_dir = './logs/log_resvit_playstyle'
writer = SummaryWriter(tb_log_dir)

In [None]:
# Hold out 0.4% of the games for validation; random_state pins the shuffle so the split is reproducible.
(train_games, val_games,
 train_y, val_y) = train_test_split(
    games, y_list,
    test_size=0.004,
    random_state=64,
)

In [None]:
# The split must partition `games` exactly.
assert len(train_games) + len(val_games) == len(games)
len(train_games), len(val_games)

In [None]:
# Labels must stay aligned one-to-one with their games after the split
# (checked before the augmentation cell replicates train_y).
assert len(train_y) == len(train_games) and len(val_y) == len(val_games)
len(train_y), len(val_y)

In [None]:
# Validation inputs: for each game, the planes prepare_input builds from the full move list,
# followed by 2 planes from each of up to `pre_data_num` earlier positions, zero-padded with
# (2, 19, 19) blocks when the game is too short to supply that many.
# NOTE(review): history frame i is built from move_list[:last_idx - 1 - i], so the position right
# before the final move (move_list[:last_idx]) is never included — confirm this offset is intended
# (it must match however inputs are built at inference time).
val_x_list = []
for val_game in val_games:
    val_move_list = val_game.split(',')
    last_idx = len(val_move_list) - 1
    n_history = min(last_idx, pre_data_num)
    planes = [prepare_input(val_move_list)]
    planes += [prepare_input(val_move_list[:last_idx - 1 - i])[:2] for i in range(n_history)]
    planes += [torch.zeros((2, 19, 19))] * (pre_data_num - n_history)
    # One concatenation per game instead of growing the tensor frame by frame.
    val_x_list.append(torch.cat(planes, dim=0))

val_x_list = torch.stack(val_x_list).to(device)
val_y = val_y.to(device)

val_dataset = CustomDataset(val_x_list, val_y)
# A single batch holds the whole (small) validation split.
data_loader_val = DataLoader(dataset=val_dataset, batch_size=len(val_games), shuffle=False)

In [None]:
# Training inputs: same plane layout as the validation inputs — planes from prepare_input on the
# full move list, then 2 planes per earlier position, zero-padded up to `pre_data_num` positions.
# NOTE(review): history frame i uses move_list[:last_idx - 1 - i], skipping move_list[:last_idx];
# confirm this offset is intended.
train_x_list = []
for train_game in train_games:
    train_move_list = train_game.split(',')
    last_idx = len(train_move_list) - 1
    n_history = min(last_idx, pre_data_num)
    planes = [prepare_input(train_move_list)]
    planes += [prepare_input(train_move_list[:last_idx - 1 - i])[:2] for i in range(n_history)]
    planes += [torch.zeros((2, 19, 19))] * (pre_data_num - n_history)
    train_x_list.append(torch.cat(planes, dim=0))
train_x_list = torch.stack(train_x_list)

# Augment with exact symmetries of the board. torch.flip / torch.rot90 only permute grid cells.
# The previous transforms.RandomRotation(90/180/270) drew one *random* angle in (-d, d) for the
# whole batch and resampled the planes, producing off-grid, partially blanked positions.
board_dims = (-2, -1)
augmented_views = [
    train_x_list,                                     # identity
    torch.flip(train_x_list, dims=(-1,)),             # horizontal flip (left-right)
    torch.flip(train_x_list, dims=(-2,)),             # vertical flip (top-bottom)
    torch.rot90(train_x_list, k=1, dims=board_dims),  # 90 degrees
    torch.rot90(train_x_list, k=2, dims=board_dims),  # 180 degrees
    torch.rot90(train_x_list, k=3, dims=board_dims),  # 270 degrees
]
n_views = len(augmented_views)
train_x_list = torch.cat(augmented_views, dim=0)
del augmented_views  # drop the per-view copies before moving the result to the device
print(train_x_list.shape)
train_x_list = train_x_list.to(device)

# One copy of the labels per view, in the same order as the views above. Slicing back to the first
# len(train_games) rows (the un-augmented labels) keeps this cell safe to re-run.
train_y = torch.cat([train_y[:len(train_games)]] * n_views, dim=0)
print(train_y.shape)
train_y = train_y.to(device)

train_dataset = CustomDataset(train_x_list, train_y)
data_loader_train = DataLoader(dataset=train_dataset, batch_size=bs, shuffle=True)

In [None]:
# Per-sample input shape (channels, height, width): every dimension after the batch axis.
train_x_list.shape[1:]

In [None]:
# The validation dataset should expose exactly one sample per held-out game.
assert len(val_dataset) == len(val_games)
len(val_dataset)

In [None]:
# Step-based training loop with periodic validation, checkpointing and TensorBoard logging.
# `total_games` counts optimizer steps (mini-batches), not individual games.
VAL_EVERY = 1                 # validate every N optimizer steps (1 = every step, as before)
CKPT_EVERY = 2000             # save a numbered full-model checkpoint every N steps
BEST_MODEL_MAX_STEP = 12000   # only track the best-validation model up to this step
ckpt_dir = './playstyle_models'
os.makedirs(ckpt_dir, exist_ok=True)

total_games = 0
best_val_loss = float('inf')

model.train()
for epoch in range(epochs):
    for x, y in tqdm(data_loader_train, desc=f'epoch {epoch}'):
        outputs = model(x)
        loss = loss_fn(outputs, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        avg_loss = loss.item()  # loss of this mini-batch
        total_games += 1

        # Targets are one-hot rows, so argmax recovers the class index.
        predicted_labels = torch.argmax(outputs, dim=1)
        true_labels = torch.argmax(y, dim=1)
        accuracy = torch.sum(predicted_labels == true_labels).item() / len(true_labels)
        writer.add_scalar('accuracy/train_game', accuracy, total_games)
        writer.add_scalar('loss/train_game', avg_loss, total_games)

        if total_games % CKPT_EVERY == 0:
            torch.save(model, os.path.join(ckpt_dir, f'model{total_games}.pth'))

        if total_games % VAL_EVERY == 0 or total_games == 1:
            model.eval()
            total_val_loss = 0.0
            total_val_correct = 0
            with torch.no_grad():
                # Accumulate over all validation batches; metrics are computed once afterwards.
                for val_xb, val_yb in data_loader_val:
                    val_outputs = model(val_xb)
                    total_val_loss += loss_fn(val_outputs, val_yb).item() * val_xb.shape[0]
                    val_predicted_labels = torch.argmax(val_outputs, dim=1)
                    val_true_labels = torch.argmax(val_yb, dim=1)
                    total_val_correct += torch.sum(val_predicted_labels == val_true_labels).item()
            avg_loss_val = total_val_loss / len(val_dataset)
            val_accuracy = total_val_correct / len(val_dataset)
            writer.add_scalar('loss/val_game', avg_loss_val, total_games)
            writer.add_scalar('accuracy/val_game', val_accuracy, total_games)

            if total_games <= BEST_MODEL_MAX_STEP and avg_loss_val < best_val_loss:
                best_val_loss = avg_loss_val
                torch.save(model, os.path.join(ckpt_dir, 'best_PS_model.pth'))
            model.train()

    # StepLR decays the learning rate once per completed epoch. It must run after optimizer.step();
    # calling it at the top of the epoch skipped the initial learning rate entirely.
    scheduler.step()