# TRAIN 

In [72]:
import os
import numpy as np
import time
import matplotlib.pyplot as plt
from tqdm import tqdm
from easydict import EasyDict
import copy
import json
from sklearn.metrics import confusion_matrix


import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import models
from torch.utils.data import DataLoader,random_split
from torch.utils.data import Subset

from utils import Dataset
from utils import utils

import models 

import sklearn

In [74]:
#model in [Times_AcT , AcT , Densenet]
model_name='AcT'

args=EasyDict({
    'base_model': model_name,
    'pretrained':True,
    'lr':0.05,
    'start_epoch':1,
    'num_epochs':25,
    'continue_epoch':False,
    
    #Dataset parms
    'num_classes':10,
    'batch_size':16,
    
    #Path params
    'annotation_path':'../dataset/annotation_dict.json',
    'augmented_annotation_path':'../dataset/augmented_annotation_dict.json',
    'model_path':'./models/weights/'+model_name+'/',
    'history_path':'./history/'
})

In [75]:
def Train_model(model,dataloaders_dict,criterion,optimizer,args):
    start_time=time.time()
    
    train_acc_history=[]
    val_acc_history=[]
    train_loss_history=[]
    val_loss_history=[]
    train_f1_history=[]
    val_f1_history=[]
    plot_epoch=[]
    
    best_model_wts=copy.deepcopy(model.state_dict())
    best_acc=0.0
    best_epoch=0
    model.to(device)
    num_epoch=args.num
    for epoch in range(1,num_epoch+1):
        for phase in ['train','val']:
            epoch_time=time.time()
            if phase == 'train':
                model.train()
                train_pred_class=[]
                train_ground_truths=[]
            else:
                model.eval()
                val_pred_class=[]
                val_ground_truths=[]
            running_loss=0.0
            running_corrects=0
            train_n_total=1
            
            pbar=tqdm(dataloaders_dict[phase])
            i=0
            for sample in pbar:
                #動画のみなら 'video'  joints なら 'joints' 
                #inputs=sample['video'].to(device)
                inputs=sample['joints'].to(device)
                
                labels=sample['action'].to(device)
                
                optimizer.zero_grad()
                
                with torch.set_grad_enabled(phase=='train'):
                    outputs=model(inputs)
                    loss=criterion(outputs,torch.max(labels,1)[1])
                    
                    _,preds=torch.max(outputs,1)
                    
                    if phase=='train':
                        train_pred_class.extend(preds.detach().cpu().numpy())
                        train_ground_truths.extend(torch.max(labels,1)[1].detach().cpu().numpy())
                    else:
                        val_pred_class.extend(preds.detach().cpu().numpy())
                        val_ground_truths.extend(torch.max(labels,1)[1].detach().cpu().numpy())
                    if phase=='train':
                        loss.backward()
                        optimizer.step()
                    
                running_loss+=loss.item()*inputs.size(0)
                running_corrects+=torch.sum(preds==torch.max(labels,1)[1])
                
                pbar.set_description('Phase: {} || Epoch: {} || Loss{:.5f}'.format(phase,epoch,running_loss/train_n_total))
                train_n_total+=1
            
            epoch_loss=running_loss/len(dataloaders_dict[phase].dataset)
            epoch_acc=running_corrects/len(dataloaders_dict[phase].dataset)
            
            #Train and Val have done
            epoch_end_time=time.time()-epoch_time
            if phase=='train':
                print('Training Complete in {:.0f}m {:.0f}s '.format(epoch_end_time//60,epoch_end_time%60))
            else:
                print('Validation Complete in {:.0f}m {:.0f}s '.format(epoch_end_time//60,epoch_end_time%60))
            #print('{} Loss: {:.4f} Acc:{:.4f} % '.format(phase,epoch_loss,epoch_acc))
            
            if phase =='train':
                train_acc_history.append(epoch_acc)
                train_loss_history.append(epoch_loss)
                train_pred_classes=np.asarray(train_pred_class)
                train_ground_truths=np.asarray(train_ground_truths)
                
                train_accuracy,train_precision,train_recall,train_f1=utils.Get_scores(
                    train_pred_classes , train_ground_truths)
                
                train_f1_history.append(train_f1)
                train_confusion_matrix=np.array_str(confusion_matrix(train_ground_truths,train_pred_class,labels=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))
                print('Epoch: {} || Train_Acc: {} || Train_Loss: {}'.format(
                    epoch, train_accuracy, epoch_loss
                ))
                print(f'train: \n{train_confusion_matrix}')
                plot_epoch.append(epoch)
                
                train_loss=epoch_loss
            
            # For Checkpointing and Confusion Matrix
            if phase == 'val':
                val_acc_history.append(epoch_acc)
                val_loss_history.append(epoch_loss)
                val_pred_classes = np.asarray(val_pred_class)
                val_ground_truths = np.asarray(val_ground_truths)
                
                val_accuracy,val_precision,val_recall,val_f1 = utils.Get_scores(
                    val_pred_classes, val_ground_truths
                )
                
                val_f1_history.append(val_f1)
                val_confusion_matrix = np.array_str(confusion_matrix(val_ground_truths, val_pred_classes, labels=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))
                print('Epoch: {} || Val_Acc: {} || Val_Loss: {}'.format(
                    epoch, val_accuracy, epoch_loss
                ))
                print(f'val: \n{val_confusion_matrix}')

                # Deep Copy Model if best accuracy
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())
                    best_epoch=epoch+1
                    
                # set current loss to val loss for write history
                val_loss = epoch_loss
            
            
        #モデル/optimizerを保存
        model_name= utils.save_weights(model, args, epoch, optimizer)
            
        #writehistory
        # Write History after train and validation phase
        utils.write_history(
                    args,
                    epoch,
                    train_loss,train_accuracy,train_f1,train_precision,train_recall,train_confusion_matrix,
                    val_loss,val_accuracy,val_f1,val_precision,val_recall,val_confusion_matrix,
                    best_acc)
        
        
    end=time.time()-start_time
    print('all done in {:.0f}m {:.0f}s'.format(end//60,end%60))
    print('Best val Acc {:.4f}'.format(best_acc))
        
    #load best model
    model.load_state_dict(best_model_wts)
    save_path=os.path.join(args.model_path,'best_weights_{}.pth'.format(best_epoch))
    torch.save(model.state_dict(),save_path)
    return model, train_loss_history, val_loss_history, train_acc_history, val_acc_history, train_f1, val_f1, plot_epoch


### Dataset Dataloaderの作成....

In [54]:
if args.base_model in ['Times_AcT' , 'AcT']:
    augment=False
    poseData=True
    joints_to_numpy=True
else:
    augment=True
    poseData=False
    joints_to_numpy=False
    
dataset=Dataset.BasketballDataset(annotation_dict='./dataset/annotation_dict.json',
                          augmented_dict='./dataset/augmented_annotation_dict.json',
                         augment=augment,poseData=poseData,joints_to_numpy=joints_to_numpy)


train_dataset_size=dataset.__len__()
train_num=int(train_dataset_size*0.7)
val_num=train_dataset_size-train_num

### ver 2.
train_dataset,val_dataset=torch.utils.data.random_split(dataset,[train_num,val_num])

#test用のsubset
train_subset=Subset(train_dataset,list(range(0,7000)))
val_subset=Subset(val_dataset,list(range(0,3000)))

#train_loader=DataLoader(train_subset,shuffle=True,batch_size=4)
#val_loader=DataLoader(val_subset,shuffle=False,batch_size=4)

train_loader=DataLoader(train_dataset,shuffle=True,batch_size=args.batch_size)
val_loader=DataLoader(val_dataset,shuffle=False,batch_size=args.batch_size)

dataloaders_dict={'train':train_loader,'val':val_loader}



In [None]:
class Focal_Loss(nn.Module):
    def __init__(self,gamma):
        super().__init__()
        self.gamma=gamma
        self.bceloss=nn.BCELoss(reduction='none')
    
    def forward(self,out,tar):
        bce=self.bceloss(out,tar)
        bce_exp=torch.exp(-bce)
        forcal_loss=(1-bce_exp)**self.gamma*bce
        return forcal_loss.mean()

In [71]:

model=models.Times_AcT.generate_model(128)

#損失関数の重み付けを考える。
#まずはデータ数の逆比で [87.05399061 52.08567416 37.23393574 34.6588785  15.70067739 10.6260745 9.59260217  6.26012829  5.71417565  3.15643885]
#weights=torch.tensor([87.0, 52.1, 37.2 ,34.7 , 15.7 ,10.6,9.59,  6.26,  5.71 , 3.16]).to('cuda:0')

optimizer=optim.Adam(model.parameters(),lr=args.lr)
criterion=nn.CrossEntropyLoss()
device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Train and evaluate
model, train_loss_history, val_loss_history,train_acc_history, val_acc_history, train_f1_score, val_f1_score, plot_epoch = Train_model(model,dataloaders_dict,criterion,optimizer,args,num_epoch=25)


Phase: train || Epoch: 1 || Loss32.32643: 100%|████████████████████████████████████████████████████████████████████████| 1623/1623 [01:02<00:00, 25.77it/s]
  _warn_prf(average, modifier, msg_start, len(result))


Training Complete in 1m 3s 
Epoch: 1 || Train_Acc: 0.14480526984860742 || Train_Loss: 2.021102401374361
train: 
[[  23    6  579   78    3    1    0    0    0    2]
 [  22    1  635   84    0    0    0    0    4    2]
 [ 131   29 3420  509    3    7    4    0    9    8]
 [  77   17 2083  281    1    3    4    0    2   10]
 [  11    1  263   24    0    0    0    0    0    1]
 [  55    5 1438  168    1    3    1    0    6    3]
 [  95   17 2231  331    2    2    7    0    4    2]
 [  11    2  410   75    0    1    2    0    0    1]
 [ 131   17 3790  579    1    5    3    0   11    8]
 [ 271   46 6849  988    9    8    9    0   10   13]]


Phase: val || Epoch: 1 || Loss30.75007: 100%|████████████████████████████████████████████████████████████████████████████| 696/696 [00:14<00:00, 48.44it/s]
  _warn_prf(average, modifier, msg_start, len(result))


Validation Complete in 0m 14s 
Epoch: 1 || Val_Acc: 0.16214272874348373 || Val_Loss: 1.9236066297685073
val: 
[[   0    0  304    0    0    0    0    0    0    0]
 [   0    0  322    0    0    0    0    0    0    0]
 [   0    0 1804    0    0    0    0    0    0    0]
 [   0    0 1012    0    0    0    0    0    0    0]
 [   0    0  126    0    0    0    0    0    0    0]
 [   0    0  682    0    0    0    0    0    0    0]
 [   0    0 1175    0    0    0    0    0    0    0]
 [   0    0  210    0    0    0    0    0    0    0]
 [   0    0 1945    0    0    0    0    0    0    0]
 [   0    0 3546    0    0    0    0    0    0    0]]


Phase: train || Epoch: 2 || Loss31.59774: 100%|████████████████████████████████████████████████████████████████████████| 1623/1623 [01:03<00:00, 25.54it/s]
  _warn_prf(average, modifier, msg_start, len(result))


Training Complete in 1m 4s 
Epoch: 2 || Train_Acc: 0.146846950961131 || Train_Loss: 1.9755436035392695
train: 
[[  16    4  608   64    0    0    0    0    0    0]
 [  12    3  655   78    0    0    0    0    0    0]
 [ 125   26 3597  372    0    0    0    0    0    0]
 [  56   13 2213  196    0    0    0    0    0    0]
 [   9    3  263   25    0    0    0    0    0    0]
 [  42    8 1492  138    0    0    0    0    0    0]
 [  58    9 2396  228    0    0    0    0    0    0]
 [   9    3  448   42    0    0    0    0    0    0]
 [ 116   20 4017  392    0    0    0    0    0    0]
 [ 206   44 7230  723    0    0    0    0    0    0]]


Phase: val || Epoch: 2 || Loss31.58879: 100%|████████████████████████████████████████████████████████████████████████████| 696/696 [00:14<00:00, 48.48it/s]
  _warn_prf(average, modifier, msg_start, len(result))


Validation Complete in 0m 14s 
Epoch: 2 || Val_Acc: 0.16214272874348373 || Val_Loss: 1.9760738278429606
val: 
[[   0    0  304    0    0    0    0    0    0    0]
 [   0    0  322    0    0    0    0    0    0    0]
 [   0    0 1804    0    0    0    0    0    0    0]
 [   0    0 1012    0    0    0    0    0    0    0]
 [   0    0  126    0    0    0    0    0    0    0]
 [   0    0  682    0    0    0    0    0    0    0]
 [   0    0 1175    0    0    0    0    0    0    0]
 [   0    0  210    0    0    0    0    0    0    0]
 [   0    0 1945    0    0    0    0    0    0    0]
 [   0    0 3546    0    0    0    0    0    0    0]]


Phase: train || Epoch: 3 || Loss31.80205: 100%|████████████████████████████████████████████████████████████████████████| 1623/1623 [01:03<00:00, 25.59it/s]
  _warn_prf(average, modifier, msg_start, len(result))


Training Complete in 1m 3s 
Epoch: 3 || Train_Acc: 0.14823375322624138 || Train_Loss: 1.9883170345723296
train: 
[[  12    7  590   82    0    0    0    0    1    0]
 [  26    6  636   79    0    0    0    0    1    0]
 [ 116   31 3560  408    0    1    0    0    4    0]
 [  58   21 2134  264    0    1    0    0    0    0]
 [   8    1  264   27    0    0    0    0    0    0]
 [  46   13 1442  177    0    1    0    0    1    0]
 [  71   25 2312  283    0    0    0    0    0    0]
 [  14    5  421   62    0    0    0    0    0    0]
 [ 123   42 3892  482    0    0    1    0    5    0]
 [ 220   77 6999  901    0    1    1    0    4    0]]


Phase: val || Epoch: 3 || Loss31.00762: 100%|████████████████████████████████████████████████████████████████████████████| 696/696 [00:14<00:00, 48.96it/s]
  _warn_prf(average, modifier, msg_start, len(result))


Validation Complete in 0m 14s 
Epoch: 3 || Val_Acc: 0.16214272874348373 || Val_Loss: 1.9397178682817222
val: 
[[   0    0  304    0    0    0    0    0    0    0]
 [   0    0  322    0    0    0    0    0    0    0]
 [   0    0 1804    0    0    0    0    0    0    0]
 [   0    0 1012    0    0    0    0    0    0    0]
 [   0    0  126    0    0    0    0    0    0    0]
 [   0    0  682    0    0    0    0    0    0    0]
 [   0    0 1175    0    0    0    0    0    0    0]
 [   0    0  210    0    0    0    0    0    0    0]
 [   0    0 1945    0    0    0    0    0    0    0]
 [   0    0 3546    0    0    0    0    0    0    0]]


Phase: train || Epoch: 4 || Loss31.82911: 100%|████████████████████████████████████████████████████████████████████████| 1623/1623 [01:03<00:00, 25.73it/s]
  _warn_prf(average, modifier, msg_start, len(result))


Training Complete in 1m 3s 
Epoch: 4 || Train_Acc: 0.14807966408567355 || Train_Loss: 1.9900093422294596
train: 
[[  14    4  587   86    0    0    1    0    0    0]
 [  11    2  649   84    0    0    1    0    1    0]
 [  86   26 3525  478    0    0    5    0    0    0]
 [  46   15 2112  299    0    0    4    0    2    0]
 [   8    0  248   41    0    0    0    0    3    0]
 [  34    7 1473  164    0    0    1    0    1    0]
 [  61   13 2305  307    0    0    2    0    3    0]
 [  10    1  449   42    0    0    0    0    0    0]
 [ 116   23 3846  551    0    0    7    0    2    0]
 [ 191   39 7052  910    0    0    7    0    4    0]]


Phase: val || Epoch: 4 || Loss30.70015: 100%|████████████████████████████████████████████████████████████████████████████| 696/696 [00:14<00:00, 48.93it/s]
  _warn_prf(average, modifier, msg_start, len(result))


Validation Complete in 0m 14s 
Epoch: 4 || Val_Acc: 0.16214272874348373 || Val_Loss: 1.9204841477824226
val: 
[[   0    0  304    0    0    0    0    0    0    0]
 [   0    0  322    0    0    0    0    0    0    0]
 [   0    0 1804    0    0    0    0    0    0    0]
 [   0    0 1012    0    0    0    0    0    0    0]
 [   0    0  126    0    0    0    0    0    0    0]
 [   0    0  682    0    0    0    0    0    0    0]
 [   0    0 1175    0    0    0    0    0    0    0]
 [   0    0  210    0    0    0    0    0    0    0]
 [   0    0 1945    0    0    0    0    0    0    0]
 [   0    0 3546    0    0    0    0    0    0    0]]


Phase: train || Epoch: 5 || Loss31.77880: 100%|████████████████████████████████████████████████████████████████████████| 1623/1623 [01:03<00:00, 25.50it/s]
  _warn_prf(average, modifier, msg_start, len(result))


Training Complete in 1m 4s 
Epoch: 5 || Train_Acc: 0.14800261951538965 || Train_Loss: 1.9868634935474032
train: 
[[  11    5  582   92    0    1    0    0    1    0]
 [  11    4  644   88    0    1    0    0    0    0]
 [  81   20 3552  460    0    3    0    0    4    0]
 [  43    7 2151  273    0    3    0    0    1    0]
 [   6    0  258   36    0    0    0    0    0    0]
 [  34    8 1430  203    0    2    0    0    3    0]
 [  55   12 2318  302    0    2    0    0    2    0]
 [  11    3  420   67    0    1    0    0    0    0]
 [  88   18 3953  483    0    3    0    0    0    0]
 [ 155   34 7056  949    0    4    0    0    5    0]]


Phase: val || Epoch: 5 || Loss30.78407: 100%|████████████████████████████████████████████████████████████████████████████| 696/696 [00:14<00:00, 48.45it/s]
  _warn_prf(average, modifier, msg_start, len(result))


Validation Complete in 0m 14s 
Epoch: 5 || Val_Acc: 0.16214272874348373 || Val_Loss: 1.925733523983418
val: 
[[   0    0  304    0    0    0    0    0    0    0]
 [   0    0  322    0    0    0    0    0    0    0]
 [   0    0 1804    0    0    0    0    0    0    0]
 [   0    0 1012    0    0    0    0    0    0    0]
 [   0    0  126    0    0    0    0    0    0    0]
 [   0    0  682    0    0    0    0    0    0    0]
 [   0    0 1175    0    0    0    0    0    0    0]
 [   0    0  210    0    0    0    0    0    0    0]
 [   0    0 1945    0    0    0    0    0    0    0]
 [   0    0 3546    0    0    0    0    0    0    0]]


Phase: train || Epoch: 6 || Loss31.72633: 100%|████████████████████████████████████████████████████████████████████████| 1623/1623 [01:03<00:00, 25.51it/s]
  _warn_prf(average, modifier, msg_start, len(result))


Training Complete in 1m 4s 
Epoch: 6 || Train_Acc: 0.14734774066797643 || Train_Loss: 1.983583038111048
train: 
[[   9    2  587   94    0    0    0    0    0    0]
 [  13    8  627  100    0    0    0    0    0    0]
 [  68   27 3509  511    0    4    1    0    0    0]
 [  31   20 2128  297    0    2    0    0    0    0]
 [   2    3  254   41    0    0    0    0    0    0]
 [  13    8 1424  233    0    2    0    0    0    0]
 [  31   16 2285  358    0    1    0    0    0    0]
 [   8    2  431   60    0    1    0    0    0    0]
 [  64   40 3860  580    0    0    1    0    0    0]
 [ 105   65 7006 1023    0    4    0    0    0    0]]


Phase: val || Epoch: 6 || Loss31.68619: 100%|████████████████████████████████████████████████████████████████████████████| 696/696 [00:14<00:00, 48.47it/s]
  _warn_prf(average, modifier, msg_start, len(result))


Validation Complete in 0m 14s 
Epoch: 6 || Val_Acc: 0.16214272874348373 || Val_Loss: 1.982166849536131
val: 
[[   0    0  304    0    0    0    0    0    0    0]
 [   0    0  322    0    0    0    0    0    0    0]
 [   0    0 1804    0    0    0    0    0    0    0]
 [   0    0 1012    0    0    0    0    0    0    0]
 [   0    0  126    0    0    0    0    0    0    0]
 [   0    0  682    0    0    0    0    0    0    0]
 [   0    0 1175    0    0    0    0    0    0    0]
 [   0    0  210    0    0    0    0    0    0    0]
 [   0    0 1945    0    0    0    0    0    0    0]
 [   0    0 3546    0    0    0    0    0    0    0]]


Phase: train || Epoch: 7 || Loss32.21737:  19%|██████████████▏                                                          | 316/1623 [00:12<00:51, 25.43it/s]


KeyboardInterrupt: 

In [None]:
def get_action_list(anno_path='./dataset/annotation_dict.json'):
    with open(anno_path) as f:
        annotation_dict=list(json.load(f).items())
    action_list={}
    
    # Defence - "block"
    stop = [val for val in annotation_dict if val[1] == 0]
    # Passing
    ball_pass = [val for val in annotation_dict if val[1] == 1]
    # Race or Running
    race = [val for val in annotation_dict if val[1] == 2]
    # Dribble
    dribble = [val for val in annotation_dict if val[1] == 3]
    # Shooting
    shooting = [val for val in annotation_dict if val[1] == 4]
    # Ball In Hand
    ballinhand = [val for val in annotation_dict if val[1] == 5]
    # Defensive Position
    defence = [val for val in annotation_dict if val[1] == 6]
    # Pick Attempt
    pick_attempt = [val for val in annotation_dict if val[1] == 7]
    # No Action - Just standing
    noaction = [val for val in annotation_dict if val[1] == 8]
    # walk
    walk = [val for val in annotation_dict if val[1] == 9]
    
    action_list={'stop':stop,'ball_pass':ball_pass,'race':race,'dribble':dribble,'shooting':shooting,
               'ballinhand':ballinhand,'defence':defence,'pick_attempt':pick_attempt,'noaction':noaction,'walk':walk}
    return action_list
action_list=get_action_list()

In [59]:
a=np.array([426,712,996,1070,2362,3490,3866,5924,6490,11749])/a.sum()

In [67]:
a_copy=np.copy(a)
print(a)

[0.01148712 0.01919914 0.02685722 0.02885264 0.06369152 0.09410813
 0.104247   0.15974114 0.17500337 0.31681273]


In [68]:
for i in range(10):
    a_copy[i]=1/a[i]

In [69]:
print(a_copy)

[87.05399061 52.08567416 37.23393574 34.6588785  15.70067739 10.6260745
  9.59260217  6.26012829  5.71417565  3.15643885]
