# Read and prepare the data

In [None]:
# global imports
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import numpy as np
import os

import tqdm

In [None]:
# Global configuration shared by the cells below.

# Locations of the .npy files read with np.load.
train_data_path = '../input/ml2021spring-hw2/timit_11/timit_11/train_11.npy'
train_label_path = '../input/ml2021spring-hw2/timit_11/timit_11/train_label_11.npy'
test_data_path = '../input/ml2021spring-hw2/timit_11/timit_11/test_11.npy'

# Optimisation settings.
epochs = 30
batch_size = 2048
lr = 1e-4
weight_decay = 1e-4
val_percent = 0.2  # fraction of the training file held out for validation

# Model dimensions.
n_classes = 39
output_dim = n_classes
hidden_dim = 1024
hidden_num = 12

# When True, the training loop writes a periodic checkpoint.
need_ckpt = True

## Load data

In [None]:
# Read the training features and their labels from disk.
train_data, train_label = (
    np.load(train_data_path),
    np.load(train_label_path),
)

In [None]:
print("train data size: ", train_data.shape)
print("train label size: ", train_label.shape)

# Sequential split: the trailing `val_percent` of the samples becomes the
# validation set, everything before it is used for training.
n_train_tot = len(train_data)
n_val = int(val_percent * n_train_tot)
n_train = n_train_tot - n_val

train_x, val_x = train_data[:n_train], train_data[n_train:]
train_y, val_y = train_label[:n_train], train_label[n_train:]

# Define the dataset

In [None]:
class TimitDataset(Dataset):
    """Map-style dataset over an in-memory feature array and optional labels.

    Args:
        X: NumPy array of features; indexed along its first axis.
        Y: Optional array of labels, one per row of ``X``. Values are cast to
            ``int64``. Pass ``None`` (the default) for unlabeled data.

    Indexing returns ``X[index]`` when no labels were given, otherwise the
    pair ``(X[index], Y[index])``.
    """

    def __init__(self, X, Y=None):
        # This is a data container, not a network layer, so it derives from
        # torch.utils.data.Dataset rather than nn.Module.
        super().__init__()
        self.X = torch.from_numpy(X)
        if Y is not None:
            self.Y = torch.LongTensor(Y.astype(np.int64))
        else:
            self.Y = None

    def __getitem__(self, index):
        if self.Y is None:
            return self.X[index]
        return self.X[index], self.Y[index]

    def __len__(self):
        return self.X.shape[0]

In [None]:
# Wrap both splits so they can be fed to a DataLoader.
train_set, val_set = TimitDataset(train_x, train_y), TimitDataset(val_x, val_y)

print("train set size is:", len(train_set))
print("validate set size is:", len(val_set))

# Free memory that is no longer needed

In [None]:
import gc

# Remove the names of the raw arrays and their slices, then trigger a
# collection pass.
del train_data, train_label
del train_x, train_y
del val_x, val_y
gc.collect()

# Define the model

In [None]:
class BasicBlock(nn.Module):
    """Linear -> BatchNorm1d -> ReLU -> Dropout, applied in that order.

    Args:
        in_dim: Size of each input sample.
        out_dim: Size of each output sample.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        stages = [
            nn.Linear(in_dim, out_dim),
            nn.BatchNorm1d(out_dim),
            nn.ReLU(),
            nn.Dropout(),  # default drop probability
        ]
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the four stages and return the result."""
        out = self.layers(x)
        return out

class TimitModel(nn.Module):
    """Two-layer GRU followed by a small fully connected classifier head.

    Args:
        out_dim: Number of output classes (size of the final linear layer).

    ``forward`` returns raw, unnormalized class scores (logits) of shape
    ``(batch, out_dim)``. No softmax is applied here: ``nn.CrossEntropyLoss``
    performs log-softmax internally, so normalizing in the model as well
    would apply softmax twice and flatten the gradients. ``argmax`` over the
    logits gives the same predicted class as ``argmax`` over probabilities.
    """

    def __init__(self, out_dim: int):
        super().__init__()

        # NOTE: the attribute is called `lstm` although it holds a GRU; the
        # name is kept so that existing state_dict checkpoints still load.
        self.lstm = nn.GRU(39, 256, 2, batch_first=True, dropout=0.25)
        # The removed Softmax layer had no parameters, so the state_dict keys
        # of this head are unchanged.
        self.out = nn.Sequential(
            BasicBlock(11*256, 1024),
            nn.Linear(1024, out_dim),
        )

    def forward(self, x):
        """Return class logits for a batch of flattened inputs.

        ``x`` is reshaped to ``(batch, 11, 39)`` -- presumably 11 context
        frames of 39 features each; confirm against the data files.
        """
        x = x.view(-1, 11, 39)
        x, _ = self.lstm(x)
        # Concatenate the GRU outputs of all 11 steps into one vector.
        x = x.contiguous().view(x.size(0), -1)
        x = self.out(x)
        return x

# Training

In [None]:
def save_ckpt(model: nn.Module, epoch:int, loss:float):
    """Write ``model``'s state_dict to a file named after ``epoch`` and ``loss``.

    The file is created in the current working directory; a one-line notice
    is printed before saving.
    """
    print('save ckpt, epoch = {}, loss = {}'.format(epoch, loss))
    ckpt_name = 'ckpt_epoch_{}_loss_{}.pth'.format(epoch, loss)
    weights = model.state_dict()
    torch.save(weights, ckpt_name)

# Fix random seeds for reproducibility

In [None]:
# fix random seed
def same_seeds(seed):
    """Seed the NumPy and PyTorch RNGs and put cuDNN in deterministic mode."""
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    # Disable cuDNN autotuning and request deterministic kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


same_seeds(0)

In [None]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = TimitModel(output_dim)
print(model)

model.to(device)

# Warm start: resume from the best weights of an earlier run when present.
if os.path.exists("best_model.pth"):
    previous_weights = torch.load('best_model.pth', map_location=device)
    model.load_state_dict(previous_weights)

opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
# NOTE(review): with a period of 1e10 steps the cosine schedule barely moves
# the learning rate over a whole run -- confirm this is intended.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=int(1e10), eta_min=1e-5)
criterion = nn.CrossEntropyLoss()

train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_set, batch_size=batch_size, shuffle=False)

# Best validation loss so far; restored from disk when an earlier run saved it.
best_loss = float('inf')
if os.path.exists('best_loss.txt'):
    with open('best_loss.txt', mode='r') as loss_file:
        best_loss = float(loss_file.readline())
        print("read best loss:", best_loss)

# Main loop: one training pass and one validation pass per epoch.
for epoch in range(epochs):
    # Switch back to training mode every epoch, because the validation pass
    # below puts the model in eval mode.
    model.train()
    if epoch % 50 == 0:
        print('trainning...epoch = %d' % epoch)
    train_tot_loss = 0
    train_tot_acc = 0
    for x, y in train_loader:
        opt.zero_grad()

        x = x.to(device=device, dtype=torch.float32)
        y_true = y.to(device=device)
        y_pred = model(x)
        # compute loss
        loss = criterion(y_pred, y_true)

        train_tot_loss += loss.item()
        _, pred_classes = torch.max(y_pred, 1)
        # Both tensors already live on `device`; compare there instead of
        # copying each batch to the CPU first.
        train_tot_acc += (pred_classes == y_true).sum().item()
        # update
        loss.backward()
        opt.step()
        scheduler.step()  # the scheduler is stepped once per batch

    print('Avg Loss/train:', train_tot_loss / len(train_loader))
    print('Avg Acc/train:', train_tot_acc / len(train_set))

    if len(val_loader) > 0:
        # Evaluate in eval mode: otherwise dropout stays active and BatchNorm
        # keeps updating its running statistics on validation batches, which
        # distorts the loss used to pick the best model.
        model.eval()
        val_tot_loss = 0
        with torch.no_grad():
            for x, y in val_loader:
                x = x.to(device=device, dtype=torch.float32)
                y_true = y.to(device=device)
                y_pred = model(x)
                val_tot_loss += criterion(y_pred, y_true).item()

        val_loss = val_tot_loss / len(val_loader)
        # Persist the weights and the loss only when validation improved.
        if val_loss < best_loss:
            print('Avg Loss/Validate:', val_loss)
            with open('best_loss.txt', mode='w') as f:
                f.write(str(val_loss))
            best_loss = val_loss
            torch.save(model.state_dict(), 'best_model.pth')

    # Periodic checkpoint every 5 epochs, tagged with the average train loss.
    if need_ckpt and epoch % 5 == 0:
        save_ckpt(model, epoch, train_tot_loss / len(train_loader))
    

In [None]:
# Rebuild the network and restore the best weights saved during training.
test_model = TimitModel(output_dim)
best_weights = torch.load('best_model.pth', map_location=device)
test_model.load_state_dict(best_weights)
test_model = test_model.to(device).eval()

# Unlabeled test features, served one sample per batch in file order.
test_data = np.load(test_data_path)
print("test data size:", test_data.shape)
test_set = TimitDataset(test_data)
test_loader = DataLoader(dataset=test_set, batch_size=1, shuffle=False)

# Testing

In [None]:
# Predicted class index for every test sample, in loader order.
classes = []
with torch.no_grad():
    for x in test_loader:
        x = x.to(device=device, dtype=torch.float32)
        y_pred = test_model(x)
        _, pred_classes = torch.max(y_pred, 1)
        # Keep every prediction in the batch, not just the first one, so the
        # result stays complete whatever batch size the loader uses.
        classes.extend(pred_classes.cpu().tolist())

In [None]:
import pandas as pd

# One row per prediction; the row index is written as the 'Id' column.
df = pd.DataFrame({'Class': classes}).rename_axis('Id')
df.to_csv('submission.csv')