In [264]:
import torch
import torch.nn as nn
from Memory.dataset import *
from Utils.utils import *
from torch.utils.data import Dataset,DataLoader
from sklearn.metrics import precision_score,recall_score

In [272]:
from torch.utils.tensorboard import SummaryWriter

# default `log_dir` is "runs" - we'll be more specific here
writer = SummaryWriter('TensorBoard/TrainAgent_RNN_OldData_5layer')

In [273]:
class Feeder(Dataset):
    def __init__(self,data_dir):
        self.data_dir = data_dir
        self.paths = os.listdir(data_dir)
        data = list()
        for file_dir in self.paths:
            f = open(data_dir + file_dir)
            try:
                js = json.load(f)['logs']
                data.extend(js)
            except json.JSONDecodeError:
                f.close()
                continue
            f.close()
        
#         self.data = list()
#         for idx,i in enumerate(data):
#             if decode_state_old_test(i,isImitation=True)[-1]  in ran:
#                 ip = list()
#                 try:
#                     for idx_2 in range(10,0,-1):
#                         ip.append(data[idx-idx_2])
#                 except IndexError:
#                     continue
#                 ip.append(i)
                
        self.data = data
             

    def __len__(self):
        return len(self.data) - 7

    def __getitem__(self, index):
        index = index + 6
        reward = Reward(self.data[index]['player_board_card_info'],self.data[index]['opponent_board_card_info'],self.data[index]['player_hand_card_id'],self.data[index]['opponent_life'],self.data[index]['player_life'],self.data[index]['player_gold'])
        return_reward = [max(reward[:56])] +[max(reward[56:105])] + [max(reward[105:114])]  + [max(reward[114:177])] + [max(reward[177:289])] + [reward[289]]          
        return_data = list()
        for i in self.data[index-4:index+1]:
            return_data.append(decode_state_old_test(i,isImitation=True)[:-1])
        return np.array(return_data),np.array(return_reward,dtype=np.float)


In [274]:
dataset = Feeder('Log/ImitationLog/Train/')

In [275]:
dataloader = DataLoader(dataset,batch_size=256,num_workers = 20)

In [276]:
def accuracy(pred, label):
    pred = torch.argmax(pred, dim=1).long()
#     print(pred)
    label = torch.argmax(label,dim=1).long()
    acc = torch.mean((pred == label).float())
    pred = pred.detach().cpu().numpy()
    label = label.detach().cpu().numpy()
    p = precision_score(label, pred,average='micro')
    r = recall_score(label, pred,average='micro')
    return p,r,acc 

class AverageMeter(object):

    def __init__(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [281]:
class RNN(nn.Module):
    def __init__(self,input_size,output_size,hidden_size):
        super(RNN, self).__init__()

        self.dense1 = nn.Linear(input_size,1024)
        self.dense2 = nn.Linear(1024,512)

        self.rnn = nn.LSTM(
            input_size=512,
            hidden_size=hidden_size,
            num_layers=5,
            batch_first=True)
        self.dense3 = nn.Linear(hidden_size,output_size)

    def forward(self, x):
        dense1 = self.dense1(x.float())
        dense2 = self.dense2(dense1)
        out, (h_n, h_c) = self.rnn(dense2, None)
        out = self.dense3(out)
        return out[:, -1, :]

In [282]:
model = RNN(516,6,512).to(torch.device('cuda:2'))

In [None]:
crit = nn.BCEWithLogitsLoss().to(torch.device('cuda:2'))
losses = 0.0
accs  = 0.0
precisions  = AverageMeter()
recalls  = AverageMeter()

opt = torch.optim.SGD(model.parameters(), 0.01, 
                          momentum=0.9, 
                          weight_decay=1e-4) 
# checkpoint = torch.load('Models/1stImitationAgent_NewData_Oldecode.ckpt')
# model.load_state_dict(checkpoint)
for epoch in range(20):
    for i, (data,reward) in enumerate(dataloader):
        data_,reward_ = map(lambda x:x.to(torch.device('cuda:2')),(data,reward))

        pred = model.forward(data_.float())
#         y_onehot = label.detach().numpy()
#         y_onehot = (np.arange(56) == y_onehot[:,None]).astype(np.float32)
#         y_onehot = torch.from_numpy(y_onehot).to(torch.device('cuda:1'))
        loss = crit(pred,torch.nn.functional.softmax(reward_))
        p,r, acc = accuracy(pred, torch.nn.functional.softmax(reward_))

        opt.zero_grad()
        loss.backward()
        opt.step()

#         losses.update(loss.item(),530)
#         accs.update(acc.item(),530)
        losses += loss.item()
        accs += acc.item()
        precisions.update(p, 530)
        recalls.update(r,530)

        if i % 5 == 0:
            print('Epoch:[{0}][{1}/{2}]\t'
#                   'Loss {losses.val:.3f} ({losses.avg:.3f})\t'
#                   'Accuracy {accs.val:.3f} ({accs.avg:.3f})\t'
                  'Loss {losses:.3f} \t'
                  'Accuracy {accs:.3f}\t'
                  'Precison {precisions.val:.3f} ({precisions.avg:.3f})\t'
                  'Recall {recalls.val:.3f} ({recalls.avg:.3f})'.format(
                        epoch,i,len(dataloader) ,losses=loss, accs=acc, 
                        precisions=precisions, recalls=recalls))
        if i % 10 == 0:
            torch.save(model.state_dict(), 'Models/ImitationAgent_RNN_OldData_5Layer.ckpt')
            # ...log the running loss
            writer.add_scalar('training loss',
                            losses / 10,
                            epoch * len(dataloader) + i)

            # ...log a Matplotlib Figure showing the model's predictions on a
            # random mini-batch
            writer.add_scalar('accuracy',
                            accs / 10,
                            global_step=epoch * len(dataloader) + i)
            losses = 0.0
            accs = 0.0
    #         print('Epoch:[{0}][{1}/{2}]\t'
    #               'Loss {losses:.3f} \t'
    #               'Accuracy {accs:.3f}\t'
    #               'Precison {precisions:.3f}\t'
    #               'Recall {recalls:.3f}'.format(
    #                     epoch, losses=loss, accs=acc, 
    #                     precisions=



Epoch:[0][0/1224]	Loss 0.698 	Accuracy 0.352	Precison 0.352 (0.352)	Recall 0.352 (0.352)
Epoch:[0][5/1224]	Loss 0.694 	Accuracy 0.191	Precison 0.191 (0.288)	Recall 0.191 (0.288)
Epoch:[0][10/1224]	Loss 0.686 	Accuracy 0.613	Precison 0.613 (0.294)	Recall 0.613 (0.294)
Epoch:[0][15/1224]	Loss 0.676 	Accuracy 0.281	Precison 0.281 (0.269)	Recall 0.281 (0.269)
Epoch:[0][20/1224]	Loss 0.666 	Accuracy 0.316	Precison 0.316 (0.254)	Recall 0.316 (0.254)
Epoch:[0][25/1224]	Loss 0.652 	Accuracy 0.707	Precison 0.707 (0.307)	Recall 0.707 (0.307)
Epoch:[0][30/1224]	Loss 0.647 	Accuracy 0.000	Precison 0.000 (0.330)	Recall 0.000 (0.330)
Epoch:[0][35/1224]	Loss 0.628 	Accuracy 0.781	Precison 0.781 (0.349)	Recall 0.781 (0.349)
Epoch:[0][40/1224]	Loss 0.618 	Accuracy 0.547	Precison 0.547 (0.405)	Recall 0.547 (0.405)
Epoch:[0][45/1224]	Loss 0.609 	Accuracy 0.398	Precison 0.398 (0.413)	Recall 0.398 (0.413)
Epoch:[0][50/1224]	Loss 0.598 	Accuracy 0.344	Precison 0.344 (0.390)	Recall 0.344 (0.390)
Epoch:[0][55

Epoch:[0][455/1224]	Loss 0.494 	Accuracy 0.000	Precison 0.000 (0.421)	Recall 0.000 (0.421)
Epoch:[0][460/1224]	Loss 0.492 	Accuracy 0.000	Precison 0.000 (0.419)	Recall 0.000 (0.419)
Epoch:[0][465/1224]	Loss 0.346 	Accuracy 0.531	Precison 0.531 (0.420)	Recall 0.531 (0.420)
Epoch:[0][470/1224]	Loss 0.379 	Accuracy 0.000	Precison 0.000 (0.417)	Recall 0.000 (0.417)
Epoch:[0][475/1224]	Loss 0.287 	Accuracy 1.000	Precison 1.000 (0.417)	Recall 1.000 (0.417)
Epoch:[0][480/1224]	Loss 0.387 	Accuracy 0.551	Precison 0.551 (0.419)	Recall 0.551 (0.419)
Epoch:[0][485/1224]	Loss 0.361 	Accuracy 0.977	Precison 0.977 (0.422)	Recall 0.977 (0.422)
Epoch:[0][490/1224]	Loss 0.490 	Accuracy 0.012	Precison 0.012 (0.421)	Recall 0.012 (0.421)
Epoch:[0][495/1224]	Loss 0.354 	Accuracy 1.000	Precison 1.000 (0.426)	Recall 1.000 (0.426)
Epoch:[0][500/1224]	Loss 0.414 	Accuracy 0.996	Precison 0.996 (0.432)	Recall 0.996 (0.432)
Epoch:[0][505/1224]	Loss 0.333 	Accuracy 0.914	Precison 0.914 (0.437)	Recall 0.914 (0.437)

Epoch:[0][915/1224]	Loss 0.463 	Accuracy 1.000	Precison 1.000 (0.439)	Recall 1.000 (0.439)
Epoch:[0][920/1224]	Loss 0.378 	Accuracy 0.453	Precison 0.453 (0.441)	Recall 0.453 (0.441)
Epoch:[0][925/1224]	Loss 0.422 	Accuracy 0.289	Precison 0.289 (0.441)	Recall 0.289 (0.441)
Epoch:[0][930/1224]	Loss 0.422 	Accuracy 0.297	Precison 0.297 (0.441)	Recall 0.297 (0.441)
Epoch:[0][935/1224]	Loss 0.486 	Accuracy 0.027	Precison 0.027 (0.440)	Recall 0.027 (0.440)
Epoch:[0][940/1224]	Loss 0.423 	Accuracy 1.000	Precison 1.000 (0.441)	Recall 1.000 (0.441)
Epoch:[0][945/1224]	Loss 0.299 	Accuracy 1.000	Precison 1.000 (0.441)	Recall 1.000 (0.441)
Epoch:[0][950/1224]	Loss 0.495 	Accuracy 0.000	Precison 0.000 (0.439)	Recall 0.000 (0.439)
Epoch:[0][955/1224]	Loss 0.331 	Accuracy 0.738	Precison 0.738 (0.439)	Recall 0.738 (0.439)
Epoch:[0][960/1224]	Loss 0.309 	Accuracy 0.859	Precison 0.859 (0.439)	Recall 0.859 (0.439)
Epoch:[0][965/1224]	Loss 0.396 	Accuracy 0.230	Precison 0.230 (0.439)	Recall 0.230 (0.439)

Epoch:[1][140/1224]	Loss 0.503 	Accuracy 0.000	Precison 0.000 (0.449)	Recall 0.000 (0.449)
Epoch:[1][145/1224]	Loss 0.270 	Accuracy 0.973	Precison 0.973 (0.451)	Recall 0.973 (0.451)
Epoch:[1][150/1224]	Loss 0.485 	Accuracy 0.086	Precison 0.086 (0.450)	Recall 0.086 (0.450)
Epoch:[1][155/1224]	Loss 0.509 	Accuracy 0.000	Precison 0.000 (0.450)	Recall 0.000 (0.450)
Epoch:[1][160/1224]	Loss 0.385 	Accuracy 0.605	Precison 0.605 (0.450)	Recall 0.605 (0.450)
Epoch:[1][165/1224]	Loss 0.315 	Accuracy 0.750	Precison 0.750 (0.450)	Recall 0.750 (0.450)
Epoch:[1][170/1224]	Loss 0.307 	Accuracy 0.691	Precison 0.691 (0.450)	Recall 0.691 (0.450)
Epoch:[1][175/1224]	Loss 0.419 	Accuracy 1.000	Precison 1.000 (0.449)	Recall 1.000 (0.449)
Epoch:[1][180/1224]	Loss 0.506 	Accuracy 0.000	Precison 0.000 (0.448)	Recall 0.000 (0.448)
Epoch:[1][185/1224]	Loss 0.421 	Accuracy 0.016	Precison 0.016 (0.446)	Recall 0.016 (0.446)
Epoch:[1][190/1224]	Loss 0.389 	Accuracy 0.078	Precison 0.078 (0.446)	Recall 0.078 (0.446)

Epoch:[1][595/1224]	Loss 0.387 	Accuracy 0.328	Precison 0.328 (0.441)	Recall 0.328 (0.441)
Epoch:[1][600/1224]	Loss 0.387 	Accuracy 0.438	Precison 0.438 (0.442)	Recall 0.438 (0.442)
Epoch:[1][605/1224]	Loss 0.376 	Accuracy 0.645	Precison 0.645 (0.441)	Recall 0.645 (0.441)
Epoch:[1][610/1224]	Loss 0.392 	Accuracy 0.609	Precison 0.609 (0.442)	Recall 0.609 (0.442)
Epoch:[1][615/1224]	Loss 0.491 	Accuracy 0.000	Precison 0.000 (0.441)	Recall 0.000 (0.441)
Epoch:[1][620/1224]	Loss 0.265 	Accuracy 1.000	Precison 1.000 (0.442)	Recall 1.000 (0.442)
Epoch:[1][625/1224]	Loss 0.479 	Accuracy 0.000	Precison 0.000 (0.442)	Recall 0.000 (0.442)
Epoch:[1][630/1224]	Loss 0.397 	Accuracy 0.348	Precison 0.348 (0.442)	Recall 0.348 (0.442)
Epoch:[1][635/1224]	Loss 0.346 	Accuracy 0.879	Precison 0.879 (0.442)	Recall 0.879 (0.442)
Epoch:[1][640/1224]	Loss 0.408 	Accuracy 0.336	Precison 0.336 (0.443)	Recall 0.336 (0.443)
Epoch:[1][645/1224]	Loss 0.374 	Accuracy 0.473	Precison 0.473 (0.444)	Recall 0.473 (0.444)

Epoch:[1][1045/1224]	Loss 0.332 	Accuracy 0.770	Precison 0.770 (0.448)	Recall 0.770 (0.448)
Epoch:[1][1050/1224]	Loss 0.263 	Accuracy 0.996	Precison 0.996 (0.448)	Recall 0.996 (0.448)
Epoch:[1][1055/1224]	Loss 0.447 	Accuracy 0.309	Precison 0.309 (0.448)	Recall 0.309 (0.448)
Epoch:[1][1060/1224]	Loss 0.367 	Accuracy 0.574	Precison 0.574 (0.448)	Recall 0.574 (0.448)
Epoch:[1][1065/1224]	Loss 0.338 	Accuracy 0.699	Precison 0.699 (0.449)	Recall 0.699 (0.449)
Epoch:[1][1070/1224]	Loss 0.397 	Accuracy 0.348	Precison 0.348 (0.449)	Recall 0.348 (0.449)
Epoch:[1][1075/1224]	Loss 0.480 	Accuracy 0.000	Precison 0.000 (0.448)	Recall 0.000 (0.448)
Epoch:[1][1080/1224]	Loss 0.505 	Accuracy 0.000	Precison 0.000 (0.447)	Recall 0.000 (0.447)
Epoch:[1][1085/1224]	Loss 0.336 	Accuracy 0.633	Precison 0.633 (0.447)	Recall 0.633 (0.447)
Epoch:[1][1090/1224]	Loss 0.436 	Accuracy 0.172	Precison 0.172 (0.447)	Recall 0.172 (0.447)
Epoch:[1][1095/1224]	Loss 0.249 	Accuracy 1.000	Precison 1.000 (0.447)	Recall 1.

Epoch:[2][270/1224]	Loss 0.330 	Accuracy 0.660	Precison 0.660 (0.447)	Recall 0.660 (0.447)
Epoch:[2][275/1224]	Loss 0.365 	Accuracy 0.465	Precison 0.465 (0.447)	Recall 0.465 (0.447)
Epoch:[2][280/1224]	Loss 0.417 	Accuracy 0.449	Precison 0.449 (0.447)	Recall 0.449 (0.447)
Epoch:[2][285/1224]	Loss 0.349 	Accuracy 0.645	Precison 0.645 (0.447)	Recall 0.645 (0.447)
Epoch:[2][290/1224]	Loss 0.388 	Accuracy 0.219	Precison 0.219 (0.446)	Recall 0.219 (0.446)
Epoch:[2][295/1224]	Loss 0.382 	Accuracy 0.387	Precison 0.387 (0.446)	Recall 0.387 (0.446)
Epoch:[2][300/1224]	Loss 0.442 	Accuracy 0.379	Precison 0.379 (0.447)	Recall 0.379 (0.447)
Epoch:[2][305/1224]	Loss 0.384 	Accuracy 0.785	Precison 0.785 (0.447)	Recall 0.785 (0.447)
Epoch:[2][310/1224]	Loss 0.345 	Accuracy 0.961	Precison 0.961 (0.447)	Recall 0.961 (0.447)
Epoch:[2][315/1224]	Loss 0.360 	Accuracy 0.605	Precison 0.605 (0.448)	Recall 0.605 (0.448)
Epoch:[2][320/1224]	Loss 0.396 	Accuracy 0.449	Precison 0.449 (0.448)	Recall 0.449 (0.448)

Epoch:[2][725/1224]	Loss 0.395 	Accuracy 0.465	Precison 0.465 (0.447)	Recall 0.465 (0.447)
Epoch:[2][730/1224]	Loss 0.251 	Accuracy 1.000	Precison 1.000 (0.447)	Recall 1.000 (0.447)
Epoch:[2][735/1224]	Loss 0.479 	Accuracy 0.055	Precison 0.055 (0.448)	Recall 0.055 (0.448)
Epoch:[2][740/1224]	Loss 0.412 	Accuracy 0.477	Precison 0.477 (0.447)	Recall 0.477 (0.447)
Epoch:[2][745/1224]	Loss 0.429 	Accuracy 0.344	Precison 0.344 (0.447)	Recall 0.344 (0.447)
Epoch:[2][750/1224]	Loss 0.426 	Accuracy 0.281	Precison 0.281 (0.447)	Recall 0.281 (0.447)
Epoch:[2][755/1224]	Loss 0.490 	Accuracy 0.020	Precison 0.020 (0.447)	Recall 0.020 (0.447)
Epoch:[2][760/1224]	Loss 0.399 	Accuracy 0.383	Precison 0.383 (0.447)	Recall 0.383 (0.447)
Epoch:[2][765/1224]	Loss 0.255 	Accuracy 1.000	Precison 1.000 (0.447)	Recall 1.000 (0.447)
Epoch:[2][770/1224]	Loss 0.249 	Accuracy 1.000	Precison 1.000 (0.448)	Recall 1.000 (0.448)
Epoch:[2][775/1224]	Loss 0.410 	Accuracy 0.422	Precison 0.422 (0.448)	Recall 0.422 (0.448)

Epoch:[2][1175/1224]	Loss 0.284 	Accuracy 0.941	Precison 0.941 (0.449)	Recall 0.941 (0.449)
Epoch:[2][1180/1224]	Loss 0.368 	Accuracy 0.652	Precison 0.652 (0.448)	Recall 0.652 (0.448)
Epoch:[2][1185/1224]	Loss 0.399 	Accuracy 0.426	Precison 0.426 (0.448)	Recall 0.426 (0.448)
Epoch:[2][1190/1224]	Loss 0.313 	Accuracy 0.676	Precison 0.676 (0.448)	Recall 0.676 (0.448)
Epoch:[2][1195/1224]	Loss 0.474 	Accuracy 1.000	Precison 1.000 (0.449)	Recall 1.000 (0.449)
Epoch:[2][1200/1224]	Loss 0.324 	Accuracy 1.000	Precison 1.000 (0.449)	Recall 1.000 (0.449)
Epoch:[2][1205/1224]	Loss 0.496 	Accuracy 0.000	Precison 0.000 (0.449)	Recall 0.000 (0.449)
Epoch:[2][1210/1224]	Loss 0.369 	Accuracy 0.574	Precison 0.574 (0.449)	Recall 0.574 (0.449)
Epoch:[2][1215/1224]	Loss 0.463 	Accuracy 0.148	Precison 0.148 (0.449)	Recall 0.148 (0.449)
Epoch:[2][1220/1224]	Loss 0.411 	Accuracy 0.262	Precison 0.262 (0.449)	Recall 0.262 (0.449)
Epoch:[3][0/1224]	Loss 0.427 	Accuracy 0.395	Precison 0.395 (0.449)	Recall 0.395

Epoch:[3][405/1224]	Loss 0.367 	Accuracy 0.418	Precison 0.418 (0.448)	Recall 0.418 (0.448)
Epoch:[3][410/1224]	Loss 0.448 	Accuracy 0.000	Precison 0.000 (0.448)	Recall 0.000 (0.448)
Epoch:[3][415/1224]	Loss 0.480 	Accuracy 0.309	Precison 0.309 (0.447)	Recall 0.309 (0.447)
Epoch:[3][420/1224]	Loss 0.492 	Accuracy 0.062	Precison 0.062 (0.447)	Recall 0.062 (0.447)
Epoch:[3][425/1224]	Loss 0.286 	Accuracy 0.773	Precison 0.773 (0.447)	Recall 0.773 (0.447)
Epoch:[3][430/1224]	Loss 0.471 	Accuracy 0.203	Precison 0.203 (0.447)	Recall 0.203 (0.447)
Epoch:[3][435/1224]	Loss 0.384 	Accuracy 0.000	Precison 0.000 (0.447)	Recall 0.000 (0.447)
Epoch:[3][440/1224]	Loss 0.453 	Accuracy 0.312	Precison 0.312 (0.447)	Recall 0.312 (0.447)
Epoch:[3][445/1224]	Loss 0.501 	Accuracy 0.000	Precison 0.000 (0.446)	Recall 0.000 (0.446)
Epoch:[3][450/1224]	Loss 0.479 	Accuracy 0.105	Precison 0.105 (0.446)	Recall 0.105 (0.446)
Epoch:[3][455/1224]	Loss 0.495 	Accuracy 0.000	Precison 0.000 (0.446)	Recall 0.000 (0.446)

Epoch:[3][860/1224]	Loss 0.377 	Accuracy 0.508	Precison 0.508 (0.447)	Recall 0.508 (0.447)
Epoch:[3][865/1224]	Loss 0.414 	Accuracy 1.000	Precison 1.000 (0.448)	Recall 1.000 (0.448)
Epoch:[3][870/1224]	Loss 0.395 	Accuracy 0.000	Precison 0.000 (0.447)	Recall 0.000 (0.447)
Epoch:[3][875/1224]	Loss 0.309 	Accuracy 0.863	Precison 0.863 (0.448)	Recall 0.863 (0.448)
Epoch:[3][880/1224]	Loss 0.329 	Accuracy 0.758	Precison 0.758 (0.447)	Recall 0.758 (0.447)
Epoch:[3][885/1224]	Loss 0.393 	Accuracy 0.250	Precison 0.250 (0.447)	Recall 0.250 (0.447)
Epoch:[3][890/1224]	Loss 0.400 	Accuracy 0.352	Precison 0.352 (0.447)	Recall 0.352 (0.447)
Epoch:[3][895/1224]	Loss 0.399 	Accuracy 0.312	Precison 0.312 (0.447)	Recall 0.312 (0.447)
Epoch:[3][900/1224]	Loss 0.434 	Accuracy 0.023	Precison 0.023 (0.447)	Recall 0.023 (0.447)
Epoch:[3][905/1224]	Loss 0.409 	Accuracy 0.188	Precison 0.188 (0.447)	Recall 0.188 (0.447)
Epoch:[3][910/1224]	Loss 0.439 	Accuracy 0.918	Precison 0.918 (0.447)	Recall 0.918 (0.447)

Epoch:[4][85/1224]	Loss 0.405 	Accuracy 0.324	Precison 0.324 (0.449)	Recall 0.324 (0.449)
Epoch:[4][90/1224]	Loss 0.312 	Accuracy 0.645	Precison 0.645 (0.449)	Recall 0.645 (0.449)
Epoch:[4][95/1224]	Loss 0.373 	Accuracy 0.355	Precison 0.355 (0.449)	Recall 0.355 (0.449)
Epoch:[4][100/1224]	Loss 0.372 	Accuracy 0.973	Precison 0.973 (0.449)	Recall 0.973 (0.449)
Epoch:[4][105/1224]	Loss 0.409 	Accuracy 0.379	Precison 0.379 (0.449)	Recall 0.379 (0.449)
Epoch:[4][110/1224]	Loss 0.468 	Accuracy 0.238	Precison 0.238 (0.449)	Recall 0.238 (0.449)
Epoch:[4][115/1224]	Loss 0.477 	Accuracy 1.000	Precison 1.000 (0.449)	Recall 1.000 (0.449)
Epoch:[4][120/1224]	Loss 0.326 	Accuracy 0.695	Precison 0.695 (0.449)	Recall 0.695 (0.449)
Epoch:[4][125/1224]	Loss 0.349 	Accuracy 0.609	Precison 0.609 (0.449)	Recall 0.609 (0.449)
Epoch:[4][130/1224]	Loss 0.302 	Accuracy 0.738	Precison 0.738 (0.449)	Recall 0.738 (0.449)
Epoch:[4][135/1224]	Loss 0.356 	Accuracy 0.500	Precison 0.500 (0.449)	Recall 0.500 (0.449)
Ep

Epoch:[4][540/1224]	Loss 0.421 	Accuracy 0.223	Precison 0.223 (0.447)	Recall 0.223 (0.447)
Epoch:[4][545/1224]	Loss 0.374 	Accuracy 0.438	Precison 0.438 (0.447)	Recall 0.438 (0.447)
Epoch:[4][550/1224]	Loss 0.486 	Accuracy 0.000	Precison 0.000 (0.448)	Recall 0.000 (0.448)
Epoch:[4][555/1224]	Loss 0.468 	Accuracy 0.070	Precison 0.070 (0.447)	Recall 0.070 (0.447)
Epoch:[4][560/1224]	Loss 0.451 	Accuracy 0.164	Precison 0.164 (0.447)	Recall 0.164 (0.447)
Epoch:[4][565/1224]	Loss 0.485 	Accuracy 0.000	Precison 0.000 (0.447)	Recall 0.000 (0.447)
Epoch:[4][570/1224]	Loss 0.481 	Accuracy 0.000	Precison 0.000 (0.446)	Recall 0.000 (0.446)
Epoch:[4][575/1224]	Loss 0.305 	Accuracy 0.832	Precison 0.832 (0.446)	Recall 0.832 (0.446)
Epoch:[4][580/1224]	Loss 0.368 	Accuracy 0.438	Precison 0.438 (0.446)	Recall 0.438 (0.446)
Epoch:[4][585/1224]	Loss 0.283 	Accuracy 0.793	Precison 0.793 (0.446)	Recall 0.793 (0.446)
Epoch:[4][590/1224]	Loss 0.418 	Accuracy 0.199	Precison 0.199 (0.446)	Recall 0.199 (0.446)

Epoch:[4][995/1224]	Loss 0.304 	Accuracy 1.000	Precison 1.000 (0.448)	Recall 1.000 (0.448)
Epoch:[4][1000/1224]	Loss 0.299 	Accuracy 0.699	Precison 0.699 (0.448)	Recall 0.699 (0.448)
Epoch:[4][1005/1224]	Loss 0.451 	Accuracy 0.020	Precison 0.020 (0.448)	Recall 0.020 (0.448)
Epoch:[4][1010/1224]	Loss 0.505 	Accuracy 0.000	Precison 0.000 (0.448)	Recall 0.000 (0.448)
Epoch:[4][1015/1224]	Loss 0.119 	Accuracy 1.000	Precison 1.000 (0.448)	Recall 1.000 (0.448)
Epoch:[4][1020/1224]	Loss 0.400 	Accuracy 0.367	Precison 0.367 (0.448)	Recall 0.367 (0.448)
Epoch:[4][1025/1224]	Loss 0.363 	Accuracy 0.656	Precison 0.656 (0.448)	Recall 0.656 (0.448)
Epoch:[4][1030/1224]	Loss 0.510 	Accuracy 0.000	Precison 0.000 (0.448)	Recall 0.000 (0.448)
Epoch:[4][1035/1224]	Loss 0.375 	Accuracy 0.508	Precison 0.508 (0.448)	Recall 0.508 (0.448)
Epoch:[4][1040/1224]	Loss 0.482 	Accuracy 0.004	Precison 0.004 (0.448)	Recall 0.004 (0.448)
Epoch:[4][1045/1224]	Loss 0.284 	Accuracy 0.770	Precison 0.770 (0.448)	Recall 0.7

Epoch:[5][220/1224]	Loss 0.325 	Accuracy 1.000	Precison 1.000 (0.448)	Recall 1.000 (0.448)
Epoch:[5][225/1224]	Loss 0.311 	Accuracy 0.961	Precison 0.961 (0.448)	Recall 0.961 (0.448)
Epoch:[5][230/1224]	Loss 0.436 	Accuracy 0.402	Precison 0.402 (0.448)	Recall 0.402 (0.448)
Epoch:[5][235/1224]	Loss 0.423 	Accuracy 0.297	Precison 0.297 (0.448)	Recall 0.297 (0.448)
Epoch:[5][240/1224]	Loss 0.392 	Accuracy 0.000	Precison 0.000 (0.448)	Recall 0.000 (0.448)
Epoch:[5][245/1224]	Loss 0.585 	Accuracy 0.848	Precison 0.848 (0.448)	Recall 0.848 (0.448)
Epoch:[5][250/1224]	Loss 0.457 	Accuracy 0.137	Precison 0.137 (0.448)	Recall 0.137 (0.448)
Epoch:[5][255/1224]	Loss 0.446 	Accuracy 0.223	Precison 0.223 (0.448)	Recall 0.223 (0.448)
Epoch:[5][260/1224]	Loss 0.388 	Accuracy 0.387	Precison 0.387 (0.448)	Recall 0.387 (0.448)
Epoch:[5][265/1224]	Loss 0.277 	Accuracy 0.707	Precison 0.707 (0.448)	Recall 0.707 (0.448)
Epoch:[5][270/1224]	Loss 0.264 	Accuracy 0.688	Precison 0.688 (0.448)	Recall 0.688 (0.448)

Epoch:[5][675/1224]	Loss 0.337 	Accuracy 0.758	Precison 0.758 (0.449)	Recall 0.758 (0.449)
Epoch:[5][680/1224]	Loss 0.435 	Accuracy 0.969	Precison 0.969 (0.449)	Recall 0.969 (0.449)
Epoch:[5][685/1224]	Loss 0.482 	Accuracy 0.000	Precison 0.000 (0.449)	Recall 0.000 (0.449)
Epoch:[5][690/1224]	Loss 0.454 	Accuracy 0.934	Precison 0.934 (0.449)	Recall 0.934 (0.449)
Epoch:[5][695/1224]	Loss 0.352 	Accuracy 0.559	Precison 0.559 (0.449)	Recall 0.559 (0.449)
Epoch:[5][700/1224]	Loss 0.407 	Accuracy 0.543	Precison 0.543 (0.449)	Recall 0.543 (0.449)
Epoch:[5][705/1224]	Loss 0.362 	Accuracy 0.543	Precison 0.543 (0.449)	Recall 0.543 (0.449)
Epoch:[5][710/1224]	Loss 0.432 	Accuracy 0.195	Precison 0.195 (0.449)	Recall 0.195 (0.449)
Epoch:[5][715/1224]	Loss 0.477 	Accuracy 0.223	Precison 0.223 (0.449)	Recall 0.223 (0.449)
Epoch:[5][720/1224]	Loss 0.483 	Accuracy 0.000	Precison 0.000 (0.449)	Recall 0.000 (0.449)
Epoch:[5][725/1224]	Loss 0.395 	Accuracy 0.414	Precison 0.414 (0.449)	Recall 0.414 (0.449)

Epoch:[5][1125/1224]	Loss 0.475 	Accuracy 0.098	Precison 0.098 (0.449)	Recall 0.098 (0.449)
Epoch:[5][1130/1224]	Loss 0.457 	Accuracy 0.176	Precison 0.176 (0.449)	Recall 0.176 (0.449)
Epoch:[5][1135/1224]	Loss 0.423 	Accuracy 1.000	Precison 1.000 (0.449)	Recall 1.000 (0.449)
Epoch:[5][1140/1224]	Loss 0.419 	Accuracy 1.000	Precison 1.000 (0.450)	Recall 1.000 (0.450)
Epoch:[5][1145/1224]	Loss 0.382 	Accuracy 0.480	Precison 0.480 (0.450)	Recall 0.480 (0.450)
Epoch:[5][1150/1224]	Loss 0.357 	Accuracy 0.551	Precison 0.551 (0.450)	Recall 0.551 (0.450)
Epoch:[5][1155/1224]	Loss 0.422 	Accuracy 0.293	Precison 0.293 (0.450)	Recall 0.293 (0.450)
Epoch:[5][1160/1224]	Loss 0.453 	Accuracy 0.145	Precison 0.145 (0.450)	Recall 0.145 (0.450)
Epoch:[5][1165/1224]	Loss 0.346 	Accuracy 1.000	Precison 1.000 (0.450)	Recall 1.000 (0.450)
Epoch:[5][1170/1224]	Loss 0.501 	Accuracy 0.000	Precison 0.000 (0.450)	Recall 0.000 (0.450)
Epoch:[5][1175/1224]	Loss 0.309 	Accuracy 0.941	Precison 0.941 (0.450)	Recall 0.

Epoch:[6][355/1224]	Loss 0.365 	Accuracy 0.516	Precison 0.516 (0.452)	Recall 0.516 (0.452)
Epoch:[6][360/1224]	Loss 0.315 	Accuracy 0.930	Precison 0.930 (0.452)	Recall 0.930 (0.452)
Epoch:[6][365/1224]	Loss 0.536 	Accuracy 0.000	Precison 0.000 (0.451)	Recall 0.000 (0.451)
Epoch:[6][370/1224]	Loss 0.390 	Accuracy 0.426	Precison 0.426 (0.452)	Recall 0.426 (0.452)
Epoch:[6][375/1224]	Loss 0.367 	Accuracy 0.441	Precison 0.441 (0.452)	Recall 0.441 (0.452)
Epoch:[6][380/1224]	Loss 0.388 	Accuracy 0.379	Precison 0.379 (0.451)	Recall 0.379 (0.451)
Epoch:[6][385/1224]	Loss 0.341 	Accuracy 0.617	Precison 0.617 (0.451)	Recall 0.617 (0.451)
Epoch:[6][390/1224]	Loss 0.398 	Accuracy 0.430	Precison 0.430 (0.451)	Recall 0.430 (0.451)
Epoch:[6][395/1224]	Loss 0.296 	Accuracy 0.508	Precison 0.508 (0.451)	Recall 0.508 (0.451)
Epoch:[6][400/1224]	Loss 0.347 	Accuracy 0.480	Precison 0.480 (0.451)	Recall 0.480 (0.451)
Epoch:[6][405/1224]	Loss 0.336 	Accuracy 0.645	Precison 0.645 (0.451)	Recall 0.645 (0.451)

Epoch:[6][810/1224]	Loss 0.487 	Accuracy 1.000	Precison 1.000 (0.453)	Recall 1.000 (0.453)
Epoch:[6][815/1224]	Loss 0.418 	Accuracy 0.375	Precison 0.375 (0.453)	Recall 0.375 (0.453)
Epoch:[6][820/1224]	Loss 0.300 	Accuracy 0.609	Precison 0.609 (0.453)	Recall 0.609 (0.453)
Epoch:[6][825/1224]	Loss 0.491 	Accuracy 0.000	Precison 0.000 (0.452)	Recall 0.000 (0.452)
Epoch:[6][830/1224]	Loss 0.385 	Accuracy 0.398	Precison 0.398 (0.452)	Recall 0.398 (0.452)
Epoch:[6][835/1224]	Loss 0.430 	Accuracy 0.242	Precison 0.242 (0.452)	Recall 0.242 (0.452)
Epoch:[6][840/1224]	Loss 0.326 	Accuracy 1.000	Precison 1.000 (0.452)	Recall 1.000 (0.452)
Epoch:[6][845/1224]	Loss 0.408 	Accuracy 0.352	Precison 0.352 (0.452)	Recall 0.352 (0.452)
Epoch:[6][850/1224]	Loss 0.473 	Accuracy 1.000	Precison 1.000 (0.452)	Recall 1.000 (0.452)
Epoch:[6][855/1224]	Loss 0.483 	Accuracy 0.000	Precison 0.000 (0.452)	Recall 0.000 (0.452)
Epoch:[6][860/1224]	Loss 0.368 	Accuracy 0.555	Precison 0.555 (0.452)	Recall 0.555 (0.452)

Epoch:[7][35/1224]	Loss 0.383 	Accuracy 0.777	Precison 0.777 (0.453)	Recall 0.777 (0.453)
Epoch:[7][40/1224]	Loss 0.396 	Accuracy 0.500	Precison 0.500 (0.453)	Recall 0.500 (0.453)
Epoch:[7][45/1224]	Loss 0.398 	Accuracy 0.414	Precison 0.414 (0.453)	Recall 0.414 (0.453)
Epoch:[7][50/1224]	Loss 0.397 	Accuracy 0.387	Precison 0.387 (0.453)	Recall 0.387 (0.453)
Epoch:[7][55/1224]	Loss 0.408 	Accuracy 0.367	Precison 0.367 (0.453)	Recall 0.367 (0.453)
Epoch:[7][60/1224]	Loss 0.345 	Accuracy 0.559	Precison 0.559 (0.453)	Recall 0.559 (0.453)
Epoch:[7][65/1224]	Loss 0.327 	Accuracy 0.945	Precison 0.945 (0.453)	Recall 0.945 (0.453)
Epoch:[7][70/1224]	Loss 0.374 	Accuracy 0.586	Precison 0.586 (0.453)	Recall 0.586 (0.453)
Epoch:[7][75/1224]	Loss 0.370 	Accuracy 0.434	Precison 0.434 (0.453)	Recall 0.434 (0.453)
Epoch:[7][80/1224]	Loss 0.419 	Accuracy 0.328	Precison 0.328 (0.453)	Recall 0.328 (0.453)
Epoch:[7][85/1224]	Loss 0.388 	Accuracy 0.543	Precison 0.543 (0.453)	Recall 0.543 (0.453)
Epoch:[7][

Epoch:[7][490/1224]	Loss 0.494 	Accuracy 0.016	Precison 0.016 (0.454)	Recall 0.016 (0.454)
Epoch:[7][495/1224]	Loss 0.303 	Accuracy 1.000	Precison 1.000 (0.454)	Recall 1.000 (0.454)
Epoch:[7][500/1224]	Loss 0.425 	Accuracy 0.996	Precison 0.996 (0.454)	Recall 0.996 (0.454)
Epoch:[7][505/1224]	Loss 0.263 	Accuracy 0.945	Precison 0.945 (0.455)	Recall 0.945 (0.455)
Epoch:[7][510/1224]	Loss 0.407 	Accuracy 0.320	Precison 0.320 (0.455)	Recall 0.320 (0.455)
Epoch:[7][515/1224]	Loss 0.502 	Accuracy 0.152	Precison 0.152 (0.455)	Recall 0.152 (0.455)
Epoch:[7][520/1224]	Loss 0.271 	Accuracy 0.770	Precison 0.770 (0.455)	Recall 0.770 (0.455)
Epoch:[7][525/1224]	Loss 0.412 	Accuracy 0.422	Precison 0.422 (0.455)	Recall 0.422 (0.455)
Epoch:[7][530/1224]	Loss 0.493 	Accuracy 0.000	Precison 0.000 (0.455)	Recall 0.000 (0.455)
Epoch:[7][535/1224]	Loss 0.494 	Accuracy 0.000	Precison 0.000 (0.454)	Recall 0.000 (0.454)
Epoch:[7][540/1224]	Loss 0.387 	Accuracy 0.379	Precison 0.379 (0.454)	Recall 0.379 (0.454)

Epoch:[7][945/1224]	Loss 0.273 	Accuracy 1.000	Precison 1.000 (0.455)	Recall 1.000 (0.455)
Epoch:[7][950/1224]	Loss 0.506 	Accuracy 0.000	Precison 0.000 (0.455)	Recall 0.000 (0.455)
Epoch:[7][955/1224]	Loss 0.356 	Accuracy 0.742	Precison 0.742 (0.455)	Recall 0.742 (0.455)
Epoch:[7][960/1224]	Loss 0.259 	Accuracy 0.859	Precison 0.859 (0.455)	Recall 0.859 (0.455)
Epoch:[7][965/1224]	Loss 0.463 	Accuracy 0.230	Precison 0.230 (0.455)	Recall 0.230 (0.455)
Epoch:[7][970/1224]	Loss 0.332 	Accuracy 0.617	Precison 0.617 (0.455)	Recall 0.617 (0.455)
Epoch:[7][975/1224]	Loss 0.254 	Accuracy 0.914	Precison 0.914 (0.455)	Recall 0.914 (0.455)
Epoch:[7][980/1224]	Loss 0.285 	Accuracy 1.000	Precison 1.000 (0.456)	Recall 1.000 (0.456)
Epoch:[7][985/1224]	Loss 0.267 	Accuracy 0.688	Precison 0.688 (0.456)	Recall 0.688 (0.456)
Epoch:[7][990/1224]	Loss 0.426 	Accuracy 0.289	Precison 0.289 (0.456)	Recall 0.289 (0.456)
Epoch:[7][995/1224]	Loss 0.309 	Accuracy 1.000	Precison 1.000 (0.456)	Recall 1.000 (0.456)

Epoch:[8][170/1224]	Loss 0.289 	Accuracy 0.797	Precison 0.797 (0.456)	Recall 0.797 (0.456)
Epoch:[8][175/1224]	Loss 0.453 	Accuracy 0.000	Precison 0.000 (0.456)	Recall 0.000 (0.456)
Epoch:[8][180/1224]	Loss 0.495 	Accuracy 0.000	Precison 0.000 (0.456)	Recall 0.000 (0.456)
Epoch:[8][185/1224]	Loss 0.350 	Accuracy 0.602	Precison 0.602 (0.455)	Recall 0.602 (0.455)
Epoch:[8][190/1224]	Loss 0.362 	Accuracy 0.770	Precison 0.770 (0.456)	Recall 0.770 (0.456)
Epoch:[8][195/1224]	Loss 0.451 	Accuracy 0.168	Precison 0.168 (0.456)	Recall 0.168 (0.456)
Epoch:[8][200/1224]	Loss 0.332 	Accuracy 0.633	Precison 0.633 (0.456)	Recall 0.633 (0.456)
Epoch:[8][205/1224]	Loss 0.377 	Accuracy 0.828	Precison 0.828 (0.456)	Recall 0.828 (0.456)
Epoch:[8][210/1224]	Loss 0.198 	Accuracy 1.000	Precison 1.000 (0.456)	Recall 1.000 (0.456)
Epoch:[8][215/1224]	Loss 0.487 	Accuracy 0.000	Precison 0.000 (0.456)	Recall 0.000 (0.456)
Epoch:[8][220/1224]	Loss 0.222 	Accuracy 1.000	Precison 1.000 (0.456)	Recall 1.000 (0.456)

Epoch:[8][625/1224]	Loss 0.483 	Accuracy 0.000	Precison 0.000 (0.456)	Recall 0.000 (0.456)
Epoch:[8][630/1224]	Loss 0.375 	Accuracy 0.402	Precison 0.402 (0.456)	Recall 0.402 (0.456)
Epoch:[8][635/1224]	Loss 0.307 	Accuracy 0.906	Precison 0.906 (0.456)	Recall 0.906 (0.456)
Epoch:[8][640/1224]	Loss 0.378 	Accuracy 0.430	Precison 0.430 (0.457)	Recall 0.430 (0.457)
Epoch:[8][645/1224]	Loss 0.371 	Accuracy 0.488	Precison 0.488 (0.457)	Recall 0.488 (0.457)
Epoch:[8][650/1224]	Loss 0.316 	Accuracy 0.602	Precison 0.602 (0.457)	Recall 0.602 (0.457)
Epoch:[8][655/1224]	Loss 0.455 	Accuracy 0.195	Precison 0.195 (0.457)	Recall 0.195 (0.457)
Epoch:[8][660/1224]	Loss 0.094 	Accuracy 1.000	Precison 1.000 (0.457)	Recall 1.000 (0.457)
Epoch:[8][665/1224]	Loss 0.089 	Accuracy 1.000	Precison 1.000 (0.457)	Recall 1.000 (0.457)
Epoch:[8][670/1224]	Loss 0.453 	Accuracy 0.156	Precison 0.156 (0.457)	Recall 0.156 (0.457)
Epoch:[8][675/1224]	Loss 0.302 	Accuracy 0.758	Precison 0.758 (0.457)	Recall 0.758 (0.457)

Epoch:[8][1075/1224]	Loss 0.439 	Accuracy 0.023	Precison 0.023 (0.458)	Recall 0.023 (0.458)
Epoch:[8][1080/1224]	Loss 0.486 	Accuracy 0.000	Precison 0.000 (0.457)	Recall 0.000 (0.457)
Epoch:[8][1085/1224]	Loss 0.308 	Accuracy 0.660	Precison 0.660 (0.457)	Recall 0.660 (0.457)
Epoch:[8][1090/1224]	Loss 0.460 	Accuracy 0.320	Precison 0.320 (0.457)	Recall 0.320 (0.457)
Epoch:[8][1095/1224]	Loss 0.348 	Accuracy 0.000	Precison 0.000 (0.457)	Recall 0.000 (0.457)
Epoch:[8][1100/1224]	Loss 0.444 	Accuracy 0.355	Precison 0.355 (0.457)	Recall 0.355 (0.457)
Epoch:[8][1105/1224]	Loss 0.542 	Accuracy 0.086	Precison 0.086 (0.457)	Recall 0.086 (0.457)
Epoch:[8][1110/1224]	Loss 0.308 	Accuracy 0.766	Precison 0.766 (0.457)	Recall 0.766 (0.457)
Epoch:[8][1115/1224]	Loss 0.385 	Accuracy 0.398	Precison 0.398 (0.457)	Recall 0.398 (0.457)
Epoch:[8][1120/1224]	Loss 0.362 	Accuracy 0.574	Precison 0.574 (0.457)	Recall 0.574 (0.457)
Epoch:[8][1125/1224]	Loss 0.480 	Accuracy 0.113	Precison 0.113 (0.457)	Recall 0.

Epoch:[9][305/1224]	Loss 0.369 	Accuracy 0.812	Precison 0.812 (0.458)	Recall 0.812 (0.458)
Epoch:[9][310/1224]	Loss 0.325 	Accuracy 0.969	Precison 0.969 (0.459)	Recall 0.969 (0.459)
Epoch:[9][315/1224]	Loss 0.314 	Accuracy 0.809	Precison 0.809 (0.459)	Recall 0.809 (0.459)
Epoch:[9][320/1224]	Loss 0.397 	Accuracy 0.430	Precison 0.430 (0.459)	Recall 0.430 (0.459)
Epoch:[9][325/1224]	Loss 0.370 	Accuracy 0.418	Precison 0.418 (0.459)	Recall 0.418 (0.459)
Epoch:[9][330/1224]	Loss 0.096 	Accuracy 1.000	Precison 1.000 (0.459)	Recall 1.000 (0.459)
Epoch:[9][335/1224]	Loss 0.334 	Accuracy 0.570	Precison 0.570 (0.459)	Recall 0.570 (0.459)
Epoch:[9][340/1224]	Loss 0.488 	Accuracy 0.066	Precison 0.066 (0.459)	Recall 0.066 (0.459)
Epoch:[9][345/1224]	Loss 0.340 	Accuracy 0.473	Precison 0.473 (0.459)	Recall 0.473 (0.459)
Epoch:[9][350/1224]	Loss 0.340 	Accuracy 0.605	Precison 0.605 (0.459)	Recall 0.605 (0.459)
Epoch:[9][355/1224]	Loss 0.359 	Accuracy 0.551	Precison 0.551 (0.459)	Recall 0.551 (0.459)

Epoch:[9][765/1224]	Loss 0.069 	Accuracy 1.000	Precison 1.000 (0.459)	Recall 1.000 (0.459)
Epoch:[9][770/1224]	Loss 0.089 	Accuracy 1.000	Precison 1.000 (0.459)	Recall 1.000 (0.459)
Epoch:[9][775/1224]	Loss 0.419 	Accuracy 0.559	Precison 0.559 (0.459)	Recall 0.559 (0.459)
Epoch:[9][780/1224]	Loss 0.226 	Accuracy 0.938	Precison 0.938 (0.460)	Recall 0.938 (0.460)
Epoch:[9][785/1224]	Loss 0.370 	Accuracy 0.652	Precison 0.652 (0.460)	Recall 0.652 (0.460)
Epoch:[9][790/1224]	Loss 0.416 	Accuracy 1.000	Precison 1.000 (0.460)	Recall 1.000 (0.460)
Epoch:[9][795/1224]	Loss 0.422 	Accuracy 0.328	Precison 0.328 (0.460)	Recall 0.328 (0.460)
Epoch:[9][800/1224]	Loss 0.498 	Accuracy 0.148	Precison 0.148 (0.460)	Recall 0.148 (0.460)
Epoch:[9][805/1224]	Loss 0.320 	Accuracy 0.914	Precison 0.914 (0.460)	Recall 0.914 (0.460)
Epoch:[9][810/1224]	Loss 0.519 	Accuracy 1.000	Precison 1.000 (0.460)	Recall 1.000 (0.460)
Epoch:[9][815/1224]	Loss 0.422 	Accuracy 0.262	Precison 0.262 (0.460)	Recall 0.262 (0.460)

Epoch:[9][1215/1224]	Loss 0.460 	Accuracy 0.121	Precison 0.121 (0.460)	Recall 0.121 (0.460)
Epoch:[9][1220/1224]	Loss 0.384 	Accuracy 0.418	Precison 0.418 (0.460)	Recall 0.418 (0.460)
Epoch:[10][0/1224]	Loss 0.423 	Accuracy 0.367	Precison 0.367 (0.460)	Recall 0.367 (0.460)
Epoch:[10][5/1224]	Loss 0.360 	Accuracy 0.496	Precison 0.496 (0.460)	Recall 0.496 (0.460)
Epoch:[10][10/1224]	Loss 0.365 	Accuracy 0.758	Precison 0.758 (0.460)	Recall 0.758 (0.460)
Epoch:[10][15/1224]	Loss 0.442 	Accuracy 0.109	Precison 0.109 (0.460)	Recall 0.109 (0.460)
Epoch:[10][20/1224]	Loss 0.446 	Accuracy 0.305	Precison 0.305 (0.460)	Recall 0.305 (0.460)
Epoch:[10][25/1224]	Loss 0.387 	Accuracy 0.605	Precison 0.605 (0.460)	Recall 0.605 (0.460)
Epoch:[10][30/1224]	Loss 0.516 	Accuracy 0.000	Precison 0.000 (0.460)	Recall 0.000 (0.460)
Epoch:[10][35/1224]	Loss 0.381 	Accuracy 0.785	Precison 0.785 (0.460)	Recall 0.785 (0.460)
Epoch:[10][40/1224]	Loss 0.393 	Accuracy 0.520	Precison 0.520 (0.460)	Recall 0.520 (0.460)

Epoch:[10][445/1224]	Loss 0.495 	Accuracy 0.000	Precison 0.000 (0.460)	Recall 0.000 (0.460)
Epoch:[10][450/1224]	Loss 0.596 	Accuracy 0.105	Precison 0.105 (0.461)	Recall 0.105 (0.461)
Epoch:[10][455/1224]	Loss 0.489 	Accuracy 0.000	Precison 0.000 (0.460)	Recall 0.000 (0.460)
Epoch:[10][460/1224]	Loss 0.492 	Accuracy 0.000	Precison 0.000 (0.460)	Recall 0.000 (0.460)
Epoch:[10][465/1224]	Loss 0.325 	Accuracy 0.602	Precison 0.602 (0.460)	Recall 0.602 (0.460)
Epoch:[10][470/1224]	Loss 0.367 	Accuracy 0.000	Precison 0.000 (0.460)	Recall 0.000 (0.460)
Epoch:[10][475/1224]	Loss 0.194 	Accuracy 1.000	Precison 1.000 (0.460)	Recall 1.000 (0.460)
Epoch:[10][480/1224]	Loss 0.369 	Accuracy 0.570	Precison 0.570 (0.460)	Recall 0.570 (0.460)
Epoch:[10][485/1224]	Loss 0.312 	Accuracy 0.977	Precison 0.977 (0.460)	Recall 0.977 (0.460)
Epoch:[10][490/1224]	Loss 0.485 	Accuracy 0.023	Precison 0.023 (0.460)	Recall 0.023 (0.460)
Epoch:[10][495/1224]	Loss 0.308 	Accuracy 1.000	Precison 1.000 (0.461)	Recall 1.

Epoch:[10][895/1224]	Loss 0.397 	Accuracy 0.352	Precison 0.352 (0.461)	Recall 0.352 (0.461)
Epoch:[10][900/1224]	Loss 0.438 	Accuracy 0.043	Precison 0.043 (0.461)	Recall 0.043 (0.461)
Epoch:[10][905/1224]	Loss 0.382 	Accuracy 0.410	Precison 0.410 (0.461)	Recall 0.410 (0.461)
Epoch:[10][910/1224]	Loss 0.429 	Accuracy 0.914	Precison 0.914 (0.461)	Recall 0.914 (0.461)
Epoch:[10][915/1224]	Loss 0.458 	Accuracy 1.000	Precison 1.000 (0.461)	Recall 1.000 (0.461)
Epoch:[10][920/1224]	Loss 0.371 	Accuracy 0.449	Precison 0.449 (0.461)	Recall 0.449 (0.461)
Epoch:[10][925/1224]	Loss 0.424 	Accuracy 0.273	Precison 0.273 (0.461)	Recall 0.273 (0.461)
Epoch:[10][930/1224]	Loss 0.418 	Accuracy 0.234	Precison 0.234 (0.461)	Recall 0.234 (0.461)
Epoch:[10][935/1224]	Loss 0.478 	Accuracy 0.059	Precison 0.059 (0.461)	Recall 0.059 (0.461)
Epoch:[10][940/1224]	Loss 0.418 	Accuracy 1.000	Precison 1.000 (0.461)	Recall 1.000 (0.461)
Epoch:[10][945/1224]	Loss 0.281 	Accuracy 1.000	Precison 1.000 (0.461)	Recall 1.

Epoch:[11][115/1224]	Loss 0.477 	Accuracy 0.000	Precison 0.000 (0.461)	Recall 0.000 (0.461)
Epoch:[11][120/1224]	Loss 0.272 	Accuracy 0.707	Precison 0.707 (0.461)	Recall 0.707 (0.461)
Epoch:[11][125/1224]	Loss 0.336 	Accuracy 0.617	Precison 0.617 (0.461)	Recall 0.617 (0.461)
Epoch:[11][130/1224]	Loss 0.251 	Accuracy 0.730	Precison 0.730 (0.461)	Recall 0.730 (0.461)
Epoch:[11][135/1224]	Loss 0.360 	Accuracy 0.562	Precison 0.562 (0.461)	Recall 0.562 (0.461)
Epoch:[11][140/1224]	Loss 0.491 	Accuracy 0.000	Precison 0.000 (0.461)	Recall 0.000 (0.461)
Epoch:[11][145/1224]	Loss 0.220 	Accuracy 0.965	Precison 0.965 (0.462)	Recall 0.965 (0.462)
Epoch:[11][150/1224]	Loss 0.493 	Accuracy 0.090	Precison 0.090 (0.462)	Recall 0.090 (0.462)
Epoch:[11][155/1224]	Loss 0.554 	Accuracy 0.000	Precison 0.000 (0.462)	Recall 0.000 (0.462)
Epoch:[11][160/1224]	Loss 0.374 	Accuracy 0.605	Precison 0.605 (0.462)	Recall 0.605 (0.462)
Epoch:[11][165/1224]	Loss 0.240 	Accuracy 0.801	Precison 0.801 (0.462)	Recall 0.

Epoch:[11][565/1224]	Loss 0.484 	Accuracy 0.000	Precison 0.000 (0.462)	Recall 0.000 (0.462)
Epoch:[11][570/1224]	Loss 0.481 	Accuracy 0.000	Precison 0.000 (0.462)	Recall 0.000 (0.462)
Epoch:[11][575/1224]	Loss 0.269 	Accuracy 0.848	Precison 0.848 (0.462)	Recall 0.848 (0.462)
Epoch:[11][580/1224]	Loss 0.346 	Accuracy 0.488	Precison 0.488 (0.462)	Recall 0.488 (0.462)
Epoch:[11][585/1224]	Loss 0.218 	Accuracy 0.785	Precison 0.785 (0.462)	Recall 0.785 (0.462)
Epoch:[11][590/1224]	Loss 0.423 	Accuracy 0.238	Precison 0.238 (0.462)	Recall 0.238 (0.462)
Epoch:[11][595/1224]	Loss 0.364 	Accuracy 0.492	Precison 0.492 (0.462)	Recall 0.492 (0.462)
Epoch:[11][600/1224]	Loss 0.348 	Accuracy 0.438	Precison 0.438 (0.462)	Recall 0.438 (0.462)
Epoch:[11][605/1224]	Loss 0.391 	Accuracy 0.379	Precison 0.379 (0.462)	Recall 0.379 (0.462)
Epoch:[11][610/1224]	Loss 0.388 	Accuracy 0.645	Precison 0.645 (0.462)	Recall 0.645 (0.462)
Epoch:[11][615/1224]	Loss 0.478 	Accuracy 0.000	Precison 0.000 (0.462)	Recall 0.

Epoch:[11][1015/1224]	Loss 0.071 	Accuracy 1.000	Precison 1.000 (0.463)	Recall 1.000 (0.463)
Epoch:[11][1020/1224]	Loss 0.382 	Accuracy 0.414	Precison 0.414 (0.463)	Recall 0.414 (0.463)
Epoch:[11][1025/1224]	Loss 0.355 	Accuracy 0.598	Precison 0.598 (0.463)	Recall 0.598 (0.463)
Epoch:[11][1030/1224]	Loss 0.486 	Accuracy 0.000	Precison 0.000 (0.463)	Recall 0.000 (0.463)
Epoch:[11][1035/1224]	Loss 0.361 	Accuracy 0.473	Precison 0.473 (0.463)	Recall 0.473 (0.463)


In [262]:
len(dataloader)

1224

In [240]:
model = RNN(516,290,512)

In [243]:
model.forward(data)

torch.LongTensor


tensor([[-0.0257, -0.0630, -0.0623,  0.1361, -0.0200, -0.0039,  0.0120,  0.0412,
         -0.0994,  0.0217, -0.0275,  0.1178,  0.0206, -0.0918, -0.0219, -0.0159,
         -0.0219, -0.0197, -0.1264, -0.1075,  0.1100,  0.0266, -0.0780, -0.1267,
         -0.0762, -0.0070, -0.0137,  0.0014, -0.1870, -0.0176, -0.0502, -0.0260,
          0.0059,  0.0530, -0.0432,  0.0633, -0.0933,  0.0084,  0.0375,  0.0401,
         -0.0171, -0.0616, -0.0005, -0.0459,  0.0608,  0.0465,  0.0005,  0.0145,
          0.0418, -0.1062, -0.0940,  0.0567,  0.0578,  0.0201,  0.0536,  0.1370,
         -0.0172,  0.0106,  0.0296,  0.0684,  0.0232,  0.0435, -0.0316,  0.0069,
         -0.0825, -0.1120, -0.0722,  0.0145, -0.0406, -0.0420,  0.0326,  0.0100,
          0.0343, -0.0009, -0.0210,  0.0863,  0.0389, -0.0674, -0.1065,  0.0212,
          0.0376,  0.0009,  0.0232,  0.0349,  0.0272,  0.0144,  0.0772, -0.0026,
          0.1036, -0.0863, -0.0276, -0.0064,  0.0217, -0.0501, -0.0400, -0.0052,
          0.1436, -0.0073, -

In [135]:
batch_size = 1
seq_len = 3

inp = torch.randn(batch_size, seq_len, 10).to(torch.device('cuda:0'))
h = a.init_hidden(batch_size)
h = tuple([e.data for e in h])


In [136]:
out,h = a.forward(inp,h)

torch.Size([1, 3, 10])
torch.Size([1, 3, 1024])
torch.Size([1, 3, 512])


In [138]:
out.shape

torch.Size([1, 60])

In [130]:
out.shape

torch.Size([1, 60])

In [139]:
import torch
import torch.nn as nn
import numpy as np

EPOCHS = 500
IN_SIZE = 5
NUM_SAMPLES = 5

def generate_data(rows, columns, samples):
    X = []
    y = []
    transformations = {
        '11': lambda x, y: x + y,
        '15': lambda x, y: x - y,
        '10': lambda x, y: x * y,
        '30': lambda x, y: x / y,
        '2': lambda x, y: x + y,
        }
    for j in range(samples):
        data_set = []
        for i in range(columns):
            data = []
            for val, fn in transformations.items():
                data.append(int(fn(int(val), i+j+1)))
            data_set.append(data)
        X.append(data_set)
        y.append([j+1])
    return X, y


class RNN(nn.Module):
    def __init__(self):
        super(RNN, self).__init__()

        self.rnn = nn.LSTM(
            input_size=5,
            hidden_size=NUM_SAMPLES+1,
            num_layers=2,
            batch_first=True,
    )

    def forward(self, x):
        out, (h_n, h_c) = self.rnn(x, None)
        return out[:, -1, :]	# Return output at last time-step


X, y = generate_data(IN_SIZE, 5, NUM_SAMPLES)
X = torch.FloatTensor(X)
y = torch.LongTensor(y)

rnn = RNN()
optimizer = torch.optim.Adam(rnn.parameters(), lr=0.001)
loss_func = nn.CrossEntropyLoss()     


In [148]:
item

tensor([[[14., 12., 30., 10.,  5.],
         [15., 11., 40.,  7.,  6.],
         [16., 10., 50.,  6.,  7.],
         [17.,  9., 60.,  5.,  8.],
         [18.,  8., 70.,  4.,  9.]]])

In [145]:
output

tensor([[-0.9949, -0.9869,  0.9963, -0.9894, -0.9928, -0.9960]],
       grad_fn=<SliceBackward>)

In [141]:
for j in range(EPOCHS):
    for i, item in enumerate(X):
        item.shape
        item = item.unsqueeze(0)
        output = rnn(item)
        loss = loss_func(output, y[i])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if j % 5 == 0:
        print('Loss: ', np.average(loss.detach()))


Loss:  0.67127174
Loss:  0.6707352
Loss:  0.67021835
Loss:  0.6697202
Loss:  0.66923946
Loss:  0.6687752
Loss:  0.6683264
Loss:  0.6678922
Loss:  0.66747177
Loss:  0.6670644
Loss:  0.66666967
Loss:  0.6662865
Loss:  0.66591483
Loss:  0.66555375
Loss:  0.665203
Loss:  0.664862
Loss:  0.6645304
Loss:  0.66420776
Loss:  0.6638938
Loss:  0.66358805
Loss:  0.6632902
Loss:  0.663
Loss:  0.6627171
Loss:  0.66244125
Loss:  0.66217214
Loss:  0.66190964
Loss:  0.6616534
Loss:  0.6614034
Loss:  0.6611591
Loss:  0.66092044
Loss:  0.6606872
Loss:  0.6604594
Loss:  0.6602366
Loss:  0.6600187
Loss:  0.6598054
Loss:  0.6595967
Loss:  0.65939224
Loss:  0.65919185
Loss:  0.6589953
Loss:  0.65880215
Loss:  0.6586126
Loss:  0.65842706
Loss:  0.6582459
Loss:  0.65806925
Loss:  0.6578964
Loss:  0.657727
Loss:  0.65756106
Loss:  0.6573982
Loss:  0.6572387
Loss:  0.6570824
Loss:  0.656929
Loss:  0.6567785
Loss:  0.65663105
Loss:  0.6564863
Loss:  0.6563443
Loss:  0.65620494


KeyboardInterrupt: 