In [1]:
import torch 
import torch.nn as nn
import torch.optim as optim
from wgan import FE, Discriminator, Classifier, Wasserstein_Loss, Grad_Loss
from tqdm import tqdm
import os
from scipy import io
import numpy as np
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset

In [2]:
import os
os.getcwd()

'/home/hun/바탕화면/AAI/WGAN'

In [3]:
# path
path = r'../data_preprocessed_matlab/'  # 경로는 저장 파일 경로
file_list = os.listdir(path)

print("data path check")
for i in file_list:    # 확인
    print(i, end=' ')


for i in tqdm(file_list, desc="read data"): 
    mat_file = io.loadmat(path+i)
    data = mat_file['data']
    labels = np.array(mat_file['labels'])
    val = labels.T[0].round().astype(np.int8)
    aro = labels.T[1].round().astype(np.int8)
    
    if(i=="s05.mat"): 
        Data = data
        VAL = val
        ARO = aro
        continue
        
    Data = np.concatenate((Data ,data),axis=0)   # 밑으로 쌓아서 하나로 만듬
    VAL = np.concatenate((VAL ,val),axis=0)
    ARO = np.concatenate((ARO ,aro),axis=0)

data path check
s05.mat s08.mat s17.mat s07.mat s06.mat s27.mat s22.mat s10.mat s25.mat s20.mat s16.mat s15.mat s21.mat s03.mat s02.mat s31.mat s11.mat s12.mat s32.mat s13.mat s29.mat s09.mat s28.mat s01.mat s30.mat s14.mat s19.mat s26.mat s04.mat s18.mat s24.mat s23.mat 

read data: 100%|██████████| 32/32 [00:19<00:00,  1.63it/s]


In [4]:
# eeg preprocessing

eeg_data = []
peripheral_data = []

for i in tqdm(range(len(Data)), desc="preprocess channel"):
    for j in range (40): 
        if(j < 32): # get channels 1 to 32
            eeg_data.append(Data[i][j])
        else:
            peripheral_data.append(Data[i][j])

# set data type, shape
eeg_data = np.reshape(eeg_data, (len(Data),1,32, 8064))
eeg_data = eeg_data.astype('float32')
eeg_data32 = torch.from_numpy(eeg_data)
VAL = (torch.from_numpy(VAL)).type(torch.long)

preprocess channel: 100%|██████████| 1280/1280 [00:00<00:00, 68220.00it/s]


In [5]:
#data 40 x 40 x 8064 video/trial x channel x data
#labels 40 x 4 video/trial x label (valence, arousal, dominance, liking)
#32명 -> 12 / 12 / 8

# data split
print("data split")
train_data, val_data,train_label, val_label = train_test_split(eeg_data32, VAL, test_size=0.25)
x_train, x_test, y_train, y_test = train_test_split(train_data, train_label, test_size=0.5)

# make data loader
print("make data loader")
target_dataset = TensorDataset(x_train, y_train)
source_dataset = TensorDataset(x_test, y_test)
val_dataset = TensorDataset(val_data, val_label)
target_dataloader = DataLoader(target_dataset, 64, shuffle=True)
source_dataloader = DataLoader(source_dataset, 64, shuffle=True)
val_dataloader = DataLoader(val_dataset, 64, shuffle=True)

data split
make data loader


In [6]:
# cuda
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print("device: ", device)

#model
dis = Discriminator(15960).to(device)
fe = FE(32).to(device)
classifier = Classifier().to(device)

#optim
optimizer_dis = optim.Adam(dis.parameters(),lr=0.0001,betas=(0,0.9))
optimizer_fe = optim.Adam(fe.parameters(),lr=0.0001, betas=(0,0.0))
optimizer_cls = optim.Adam(classifier.parameters(),lr=0.0001, betas=(0,0.9))

#cls_loss
criterion = nn.CrossEntropyLoss().to(device)

device:  cuda:0


In [7]:
# train WGAN
accuracy_s = []
accuracy_t = []
accuracy_val = []

best_loss = 10000000
limit_epoch = 500
limit_check = 0
val_loss = 0
nb_epochs = 1000
lambda_hyper = 10
mu_hyper = 1
n = 5

log = open('Wasserstein_log.txt', 'w')

torch.autograd.set_detect_anomaly(True)

# while parameter converge
for epoch in tqdm(range(nb_epochs)):
    temp_accuracy_t = 0
    temp_accuracy_s = 0
    temp_accuracy_val = 0
    temp_gloss = 0
    temp_wdloss = 0
    temp_gradloss = 0
    temp_clsloss = 0

    print(epoch+1, ": epoch", file=log)

    temp = 0.0 #batch count
    fe.train()
    dis.train()
    classifier.train()
    # batch
    for i, (target, source) in enumerate(zip(target_dataloader, source_dataloader)):
        temp += 1.0

        x_target = target[0].to(device)
        y_target = target[1].to(device)
        x_source = source[0].to(device)
        y_source = source[1].to(device)

        # update discriminator
        for p in fe.parameters() :
            p.requires_grad = False
        for p in dis.parameters() :
            p.requires_grad = True
        for p in classifier.parameters() :
            p.requires_grad = False
        
        for k in range(n) :
            optimizer_dis.zero_grad()
            wd_grad_loss = 0
            feat_t = fe(x_target)
            feat_s = fe(x_source)
            pred_t = classifier(feat_t)
            pred_s = classifier(feat_s)
            for j in range(feat_s.size(0)) :
                epsil = torch.rand(1).item()
                feat = epsil*feat_s[j,:]+(1-epsil)*feat_t[j,:]
                dc_t = dis(feat_t)
                dc_s = dis(feat_s)
                wd_loss = Wasserstein_Loss(dc_s, dc_t)
                grad_loss = Grad_Loss(feat, dis, device)
                wd_grad_loss = wd_grad_loss - (wd_loss-lambda_hyper*grad_loss)
            wd_grad_loss = wd_grad_loss / feat_s.size(0)
            wd_grad_loss.backward()
            optimizer_dis.step()

        # update classifier
        for p in fe.parameters() :
            p.requires_grad = False
        for p in dis.parameters() :
            p.requires_grad = False
        for p in classifier.parameters() :
            p.requires_grad = True
        
        optimizer_cls.zero_grad()
        feat_s = fe(x_source)
        pred_s = classifier(feat_s)
        cls_loss_source = criterion(pred_s, y_source-1)
        cls_loss_source.backward()
        optimizer_cls.step()
        
        # update Feature Extractor
        for p in fe.parameters() :
            p.requires_grad = True
        for p in dis.parameters() :
            p.requires_grad = False
        for p in classifier.parameters() :
            p.requires_grad = False
        
        optimizer_fe.zero_grad()
        feat_t = fe(x_target)
        feat_s = fe(x_source)
        pred_s = classifier(feat_s)
        dc_t = dis(feat_t)
        dc_s = dis(feat_s)
        wd_loss = Wasserstein_Loss(dc_s, dc_t)
        cls_loss_source = criterion(pred_s, y_source-1)
        fe_loss = cls_loss_source + wd_loss
        fe_loss.backward()
        optimizer_fe.step()
        
        # Temp_Loss
        wd_loss = Wasserstein_Loss(dc_s, dc_t)
        cls_loss_source = criterion(pred_s, y_source-1)
        g_loss = cls_loss_source + mu_hyper*(wd_loss - lambda_hyper*grad_loss)

        feat_t = fe(x_target)
        feat_s = fe(x_source)
        pred_t = classifier(feat_t)
        pred_s = classifier(feat_s)
        
        temp_wdloss = temp_wdloss + wd_loss
        temp_clsloss = temp_clsloss + cls_loss_source
        temp_gloss = temp_gloss + g_loss

        temp_accuracy_t += ((torch.argmax(pred_t,1)+1)== y_target).to(torch.float).mean()
        temp_accuracy_s += ((torch.argmax(pred_s,1)+1)== y_source).to(torch.float).mean()
    
    print("\ngloss", temp_gloss.item()/temp, file=log)
    print("wd_loss", temp_wdloss.item()/temp, file=log)
    print("cls_loss", temp_clsloss.item()/temp, file=log)
    print("acc_t", temp_accuracy_t.item()/temp, file=log)
    print("acc_s", temp_accuracy_s.item()/temp, file=log)
    
    accuracy_t.append(temp_accuracy_t/temp)
    accuracy_s.append(temp_accuracy_s/temp)
    
    fe.eval()
    dis.eval()
    classifier.eval()
    val_loss = 0
    temp = 0
    for x_val, y_val in val_dataloader:
        x_val = x_val.to(device)
        y_val = y_val.to(device)
        pred_val = classifier(fe(x_val))
        temp_accuracy_val += ((torch.argmax(pred_val,1)+1)== y_val).to(torch.float).mean()
        loss = criterion(pred_val, y_val-1)
        val_loss += loss.item() * x_val.size(0)
        temp += 1
    val_total_loss = val_loss / len(val_dataloader.dataset)
    print("val_loss :", val_total_loss, file=log)
    print("acc_val :", temp_accuracy_val.item()/temp, file=log)
    accuracy_val.append(temp_accuracy_val.item()/temp)
    if val_total_loss > best_loss:
        limit_check += 1
        if(limit_check >= limit_epoch):
            break
    else:
        best_loss = val_total_loss
        limit_check = 0
    print()

print("accuracy_t ", sum(accuracy_t)/len(accuracy_t), file=log)
print("accuracy_s ", sum(accuracy_s)/len(accuracy_s), file=log)
print("accuracy_val", sum(accuracy_val)/len(accuracy_val), file=log)
print("best_val_loss ", best_loss, file=log)

log.close()

  0%|          | 0/1000 [00:00<?, ?it/s]

1 : epoch


  0%|          | 1/1000 [00:15<4:23:20, 15.82s/it]


gloss -6.219282150268555
wd_loss 0.4094912111759186
cls_loss 2.165621042251587
acc_t 0.044921875
acc_s 0.099609375
val_loss : 2.243357515335083
acc_val : 0.059375

2 : epoch


  0%|          | 2/1000 [00:31<4:17:17, 15.47s/it]


gloss -4.010986804962158
wd_loss 0.8942652940750122
cls_loss 2.1188952922821045
acc_t 0.037109375
acc_s 0.146484375
val_loss : 2.247661018371582
acc_val : 0.059375

3 : epoch


  0%|          | 3/1000 [00:46<4:15:04, 15.35s/it]


gloss -3.928805112838745
wd_loss 1.3024367094039917
cls_loss 2.0865671634674072
acc_t 0.041015625
acc_s 0.189453125
val_loss : 2.2503774642944334
acc_val : 0.059375

4 : epoch


  0%|          | 4/1000 [01:01<4:13:20, 15.26s/it]


gloss -2.205061197280884
wd_loss 1.68748140335083
cls_loss 2.070878744125366
acc_t 0.0390625
acc_s 0.232421875
val_loss : 2.2491321563720703
acc_val : 0.0625

5 : epoch


  0%|          | 5/1000 [01:16<4:12:15, 15.21s/it]


gloss -0.5243695974349976
wd_loss 2.1519391536712646
cls_loss 2.0527660846710205
acc_t 0.048828125
acc_s 0.263671875
val_loss : 2.248411703109741
acc_val : 0.065625

6 : epoch


  1%|          | 6/1000 [01:31<4:11:36, 15.19s/it]


gloss 3.659914493560791
wd_loss 2.915865182876587
cls_loss 2.0374057292938232
acc_t 0.048828125
acc_s 0.3046875
val_loss : 2.245560884475708
acc_val : 0.065625

7 : epoch


  1%|          | 7/1000 [01:46<4:10:47, 15.15s/it]


gloss 3.386341094970703
wd_loss 3.741833209991455
cls_loss 2.01908278465271
acc_t 0.046875
acc_s 0.3359375
val_loss : 2.244804048538208
acc_val : 0.065625

8 : epoch


  1%|          | 8/1000 [02:01<4:10:21, 15.14s/it]


gloss 0.7116458415985107
wd_loss 4.3877458572387695
cls_loss 2.004072904586792
acc_t 0.044921875
acc_s 0.361328125
val_loss : 2.2422183990478515
acc_val : 0.065625

9 : epoch


  1%|          | 9/1000 [02:16<4:09:59, 15.14s/it]


gloss 0.15572547912597656
wd_loss 5.005131244659424
cls_loss 2.0047788619995117
acc_t 0.0390625
acc_s 0.3671875
val_loss : 2.2410067558288573
acc_val : 0.065625

10 : epoch


  1%|          | 10/1000 [02:32<4:09:52, 15.14s/it]


gloss -18.1685791015625
wd_loss 5.700887680053711
cls_loss 1.995687484741211
acc_t 0.048828125
acc_s 0.38671875
val_loss : 2.240198564529419
acc_val : 0.071875

11 : epoch


  1%|          | 11/1000 [02:47<4:09:41, 15.15s/it]


gloss -64.01151275634766
wd_loss 6.921267986297607
cls_loss 1.9860751628875732
acc_t 0.046875
acc_s 0.390625
val_loss : 2.2386740684509276
acc_val : 0.075

12 : epoch


  1%|          | 12/1000 [03:02<4:09:37, 15.16s/it]


gloss -135.78665161132812
wd_loss 8.264023780822754
cls_loss 1.9795933961868286
acc_t 0.05078125
acc_s 0.40234375
val_loss : 2.2370927810668944
acc_val : 0.071875

13 : epoch


  1%|▏         | 13/1000 [03:17<4:09:16, 15.15s/it]


gloss -66.21538543701172
wd_loss 9.365934371948242
cls_loss 1.9623825550079346
acc_t 0.0625
acc_s 0.41015625
val_loss : 2.238275098800659
acc_val : 0.071875

14 : epoch


  1%|▏         | 14/1000 [03:32<4:08:52, 15.14s/it]


gloss -192.61387634277344
wd_loss 10.095184326171875
cls_loss 1.9571596384048462
acc_t 0.052734375
acc_s 0.41796875
val_loss : 2.2394955158233643
acc_val : 0.06875

15 : epoch


  2%|▏         | 15/1000 [03:47<4:08:33, 15.14s/it]


gloss -426.3749694824219
wd_loss 12.088346481323242
cls_loss 1.9428645372390747
acc_t 0.060546875
acc_s 0.44140625
val_loss : 2.237738561630249
acc_val : 0.084375

16 : epoch


  2%|▏         | 16/1000 [04:03<4:08:20, 15.14s/it]


gloss -386.57916259765625
wd_loss 13.569659233093262
cls_loss 1.934916615486145
acc_t 0.0546875
acc_s 0.443359375
val_loss : 2.2356035709381104
acc_val : 0.08125

17 : epoch


  2%|▏         | 17/1000 [04:18<4:08:04, 15.14s/it]


gloss -1221.398193359375
wd_loss 14.715400695800781
cls_loss 1.9325497150421143
acc_t 0.060546875
acc_s 0.451171875
val_loss : 2.234473705291748
acc_val : 0.071875

18 : epoch


  2%|▏         | 18/1000 [04:33<4:08:01, 15.15s/it]


gloss -885.1934814453125
wd_loss 16.26268768310547
cls_loss 1.9263108968734741
acc_t 0.060546875
acc_s 0.46875
val_loss : 2.2380990028381347
acc_val : 0.078125

19 : epoch


  2%|▏         | 19/1000 [04:48<4:07:37, 15.15s/it]


gloss -2408.59716796875
wd_loss 18.80521583557129
cls_loss 1.9243208169937134
acc_t 0.060546875
acc_s 0.455078125
val_loss : 2.2381691455841066
acc_val : 0.08125

20 : epoch


  2%|▏         | 20/1000 [05:03<4:07:27, 15.15s/it]


gloss -1912.6168212890625
wd_loss 20.589574813842773
cls_loss 1.9142754077911377
acc_t 0.05859375
acc_s 0.474609375
val_loss : 2.2360536575317385
acc_val : 0.08125

21 : epoch


  2%|▏         | 21/1000 [05:18<4:07:15, 15.15s/it]


gloss -2618.86865234375
wd_loss 22.716209411621094
cls_loss 1.9124232530593872
acc_t 0.0546875
acc_s 0.474609375
val_loss : 2.2359854221343993
acc_val : 0.090625

22 : epoch


  2%|▏         | 22/1000 [05:33<4:06:54, 15.15s/it]


gloss -5590.71533203125
wd_loss 25.617761611938477
cls_loss 1.9059383869171143
acc_t 0.05859375
acc_s 0.48046875
val_loss : 2.235266399383545
acc_val : 0.084375

23 : epoch


  2%|▏         | 23/1000 [05:49<4:06:35, 15.14s/it]


gloss -6633.8837890625
wd_loss 26.498554229736328
cls_loss 1.9032138586044312
acc_t 0.07421875
acc_s 0.486328125
val_loss : 2.231446123123169
acc_val : 0.075

24 : epoch


  2%|▏         | 24/1000 [06:04<4:06:21, 15.15s/it]


gloss -11640.359375
wd_loss 29.9429874420166
cls_loss 1.906314730644226
acc_t 0.056640625
acc_s 0.47265625
val_loss : 2.2344152450561525
acc_val : 0.078125

25 : epoch


  2%|▎         | 25/1000 [06:19<4:06:01, 15.14s/it]


gloss -7354.5498046875
wd_loss 31.736339569091797
cls_loss 1.89288330078125
acc_t 0.068359375
acc_s 0.48046875
val_loss : 2.233854293823242
acc_val : 0.096875

26 : epoch


  3%|▎         | 26/1000 [06:34<4:05:49, 15.14s/it]


gloss -13343.34765625
wd_loss 34.55559158325195
cls_loss 1.8866894245147705
acc_t 0.083984375
acc_s 0.490234375
val_loss : 2.2336291313171386
acc_val : 0.10625

27 : epoch


  3%|▎         | 27/1000 [06:49<4:05:31, 15.14s/it]


gloss -25819.84375
wd_loss 37.84110641479492
cls_loss 1.8800859451293945
acc_t 0.064453125
acc_s 0.482421875
val_loss : 2.2382907390594484
acc_val : 0.084375

28 : epoch


  3%|▎         | 28/1000 [07:04<4:05:26, 15.15s/it]


gloss -30482.388671875
wd_loss 41.05496597290039
cls_loss 1.883509635925293
acc_t 0.07421875
acc_s 0.4921875
val_loss : 2.2370689868927003
acc_val : 0.0875

29 : epoch


  3%|▎         | 29/1000 [07:19<4:05:04, 15.14s/it]


gloss -35113.71875
wd_loss 44.11396789550781
cls_loss 1.8678442239761353
acc_t 0.080078125
acc_s 0.513671875
val_loss : 2.235229825973511
acc_val : 0.0875

30 : epoch


  3%|▎         | 30/1000 [07:35<4:04:53, 15.15s/it]


gloss -38639.66796875
wd_loss 46.795936584472656
cls_loss 1.8696054220199585
acc_t 0.078125
acc_s 0.501953125
val_loss : 2.2362411499023436
acc_val : 0.084375

31 : epoch


  3%|▎         | 31/1000 [07:50<4:04:32, 15.14s/it]


gloss -47560.69921875
wd_loss 51.80742645263672
cls_loss 1.8639473915100098
acc_t 0.0703125
acc_s 0.501953125
val_loss : 2.2339430332183836
acc_val : 0.084375

32 : epoch


  3%|▎         | 32/1000 [08:05<4:04:25, 15.15s/it]


gloss -78083.421875
wd_loss 53.45508575439453
cls_loss 1.8589565753936768
acc_t 0.0859375
acc_s 0.49609375
val_loss : 2.23126482963562
acc_val : 0.090625

33 : epoch


  3%|▎         | 33/1000 [08:20<4:04:05, 15.15s/it]


gloss -61901.28515625
wd_loss 58.161415100097656
cls_loss 1.8465087413787842
acc_t 0.087890625
acc_s 0.51953125
val_loss : 2.229951190948486
acc_val : 0.0875

34 : epoch


  3%|▎         | 34/1000 [08:35<4:03:32, 15.13s/it]


gloss -82558.9609375
wd_loss 60.52728271484375
cls_loss 1.8396165370941162
acc_t 0.076171875
acc_s 0.501953125
val_loss : 2.2302874088287354
acc_val : 0.0875

35 : epoch


  4%|▎         | 35/1000 [08:50<4:03:18, 15.13s/it]


gloss -121722.6875
wd_loss 64.49632263183594
cls_loss 1.8459086418151855
acc_t 0.083984375
acc_s 0.51171875
val_loss : 2.2303934574127195
acc_val : 0.078125

36 : epoch


  4%|▎         | 36/1000 [09:05<4:03:11, 15.14s/it]


gloss -133050.578125
wd_loss 68.568603515625
cls_loss 1.8426494598388672
acc_t 0.091796875
acc_s 0.52734375
val_loss : 2.226737880706787
acc_val : 0.096875

37 : epoch


  4%|▎         | 37/1000 [09:20<4:02:44, 15.12s/it]


gloss -184985.8125
wd_loss 72.33611297607422
cls_loss 1.8202409744262695
acc_t 0.103515625
acc_s 0.576171875
val_loss : 2.224970293045044
acc_val : 0.11875

38 : epoch


  4%|▍         | 38/1000 [09:36<4:02:31, 15.13s/it]


gloss -157378.375
wd_loss 75.4693832397461
cls_loss 1.8209388256072998
acc_t 0.103515625
acc_s 0.552734375
val_loss : 2.2314943790435793
acc_val : 0.090625

39 : epoch


  4%|▍         | 39/1000 [09:51<4:02:15, 15.13s/it]


gloss -186627.90625
wd_loss 79.16190338134766
cls_loss 1.8468639850616455
acc_t 0.08203125
acc_s 0.509765625
val_loss : 2.235993528366089
acc_val : 0.084375

40 : epoch


  4%|▍         | 40/1000 [10:06<4:01:51, 15.12s/it]


gloss -214705.953125
wd_loss 81.70838928222656
cls_loss 1.8437596559524536
acc_t 0.087890625
acc_s 0.515625
val_loss : 2.2338488578796385
acc_val : 0.09375

41 : epoch


  4%|▍         | 41/1000 [10:21<4:01:35, 15.12s/it]


gloss -285317.40625
wd_loss 84.15269470214844
cls_loss 1.828176736831665
acc_t 0.083984375
acc_s 0.578125
val_loss : 2.232428026199341
acc_val : 0.096875

42 : epoch


  4%|▍         | 42/1000 [10:36<4:01:12, 15.11s/it]


gloss -288974.6875
wd_loss 91.84725952148438
cls_loss 1.7958301305770874
acc_t 0.115234375
acc_s 0.607421875
val_loss : 2.2366692066192626
acc_val : 0.090625

43 : epoch


  4%|▍         | 43/1000 [10:51<4:00:58, 15.11s/it]


gloss -263575.15625
wd_loss 94.89556884765625
cls_loss 1.7969292402267456
acc_t 0.08984375
acc_s 0.576171875
val_loss : 2.2414605617523193
acc_val : 0.084375

44 : epoch


  4%|▍         | 44/1000 [11:06<4:00:53, 15.12s/it]


gloss -400696.03125
wd_loss 98.05339050292969
cls_loss 1.8316813707351685
acc_t 0.072265625
acc_s 0.52734375
val_loss : 2.2382522583007813
acc_val : 0.09375

45 : epoch


  4%|▍         | 45/1000 [11:21<4:00:49, 15.13s/it]


gloss -648477.0
wd_loss 105.38314819335938
cls_loss 1.8124423027038574
acc_t 0.078125
acc_s 0.55078125
val_loss : 2.2321247100830077
acc_val : 0.08125

46 : epoch


  5%|▍         | 46/1000 [11:37<4:00:28, 15.12s/it]


gloss -396589.0
wd_loss 108.26069641113281
cls_loss 1.8043084144592285
acc_t 0.095703125
acc_s 0.595703125
val_loss : 2.2287286281585694
acc_val : 0.109375

47 : epoch


  5%|▍         | 47/1000 [11:52<4:00:12, 15.12s/it]


gloss -878574.125
wd_loss 114.04165649414062
cls_loss 1.7753745317459106
acc_t 0.10546875
acc_s 0.591796875
val_loss : 2.232746934890747
acc_val : 0.103125

48 : epoch


  5%|▍         | 48/1000 [12:07<3:59:58, 15.12s/it]


gloss -1003248.5625
wd_loss 118.330810546875
cls_loss 1.7873759269714355
acc_t 0.08984375
acc_s 0.572265625
val_loss : 2.231401538848877
acc_val : 0.103125

49 : epoch


  5%|▍         | 49/1000 [12:22<3:59:39, 15.12s/it]


gloss -606913.3125
wd_loss 122.7668228149414
cls_loss 1.7977852821350098
acc_t 0.0859375
acc_s 0.58203125
val_loss : 2.2330856800079344
acc_val : 0.09375

50 : epoch


  5%|▌         | 50/1000 [12:37<3:59:23, 15.12s/it]


gloss -1178546.5
wd_loss 130.3976287841797
cls_loss 1.787593126296997
acc_t 0.080078125
acc_s 0.5859375
val_loss : 2.235830068588257
acc_val : 0.096875

51 : epoch


  5%|▌         | 51/1000 [12:52<3:59:20, 15.13s/it]


gloss -1167830.375
wd_loss 128.57473754882812
cls_loss 1.78971266746521
acc_t 0.091796875
acc_s 0.583984375
val_loss : 2.239219331741333
acc_val : 0.09375

52 : epoch


  5%|▌         | 52/1000 [13:07<3:58:56, 15.12s/it]


gloss -986611.8125
wd_loss 138.199951171875
cls_loss 1.7730190753936768
acc_t 0.115234375
acc_s 0.59375
val_loss : 2.236942005157471
acc_val : 0.1

53 : epoch


  5%|▌         | 53/1000 [13:22<3:58:35, 15.12s/it]


gloss -1087450.75
wd_loss 144.49276733398438
cls_loss 1.768756628036499
acc_t 0.10546875
acc_s 0.5859375
val_loss : 2.2344104290008544
acc_val : 0.10625

54 : epoch


  5%|▌         | 54/1000 [13:37<3:58:17, 15.11s/it]


gloss -1348959.625
wd_loss 147.3529052734375
cls_loss 1.7818659543991089
acc_t 0.08984375
acc_s 0.576171875
val_loss : 2.2341258049011232
acc_val : 0.10625

55 : epoch


  6%|▌         | 55/1000 [13:53<3:57:59, 15.11s/it]


gloss -2444905.75
wd_loss 160.0980682373047
cls_loss 1.7612905502319336
acc_t 0.09765625
acc_s 0.59765625
val_loss : 2.233092451095581
acc_val : 0.115625

56 : epoch


  6%|▌         | 56/1000 [14:08<3:57:46, 15.11s/it]


gloss -1946826.5
wd_loss 161.70260620117188
cls_loss 1.7612553834915161
acc_t 0.119140625
acc_s 0.58984375
val_loss : 2.228915882110596
acc_val : 0.140625

57 : epoch


  6%|▌         | 57/1000 [14:23<3:57:30, 15.11s/it]


gloss -2138193.25
wd_loss 168.53280639648438
cls_loss 1.7421059608459473
acc_t 0.134765625
acc_s 0.615234375
val_loss : 2.2231271266937256
acc_val : 0.128125

58 : epoch


  6%|▌         | 58/1000 [14:38<3:57:22, 15.12s/it]


gloss -2890705.5
wd_loss 172.07089233398438
cls_loss 1.7379209995269775
acc_t 0.125
acc_s 0.609375
val_loss : 2.224172019958496
acc_val : 0.13125

59 : epoch


  6%|▌         | 59/1000 [14:53<3:57:15, 15.13s/it]


gloss -3935441.25
wd_loss 182.43386840820312
cls_loss 1.7328921556472778
acc_t 0.1171875
acc_s 0.61328125
val_loss : 2.22367787361145
acc_val : 0.1125

60 : epoch


  6%|▌         | 60/1000 [15:08<3:57:00, 15.13s/it]


gloss -3806654.75
wd_loss 187.77000427246094
cls_loss 1.74979567527771
acc_t 0.109375
acc_s 0.59765625
val_loss : 2.2275405406951903
acc_val : 0.1125

61 : epoch


  6%|▌         | 61/1000 [15:23<3:57:16, 15.16s/it]


gloss -4795406.5
wd_loss 196.85678100585938
cls_loss 1.729100227355957
acc_t 0.115234375
acc_s 0.62109375
val_loss : 2.2274269580841066
acc_val : 0.125

62 : epoch


  6%|▌         | 62/1000 [15:39<3:57:26, 15.19s/it]


gloss -2600251.75
wd_loss 207.16580200195312
cls_loss 1.6968010663986206
acc_t 0.115234375
acc_s 0.66015625
val_loss : 2.2264544010162353
acc_val : 0.128125

63 : epoch


  6%|▋         | 63/1000 [15:54<3:57:02, 15.18s/it]


gloss -5198238.5
wd_loss 211.4051055908203
cls_loss 1.6799423694610596
acc_t 0.13671875
acc_s 0.662109375
val_loss : 2.219622564315796
acc_val : 0.146875

64 : epoch


  6%|▋         | 64/1000 [16:09<3:56:45, 15.18s/it]


gloss -3813347.75
wd_loss 220.2966766357422
cls_loss 1.680465579032898
acc_t 0.13671875
acc_s 0.666015625
val_loss : 2.212794876098633
acc_val : 0.14375

65 : epoch


  6%|▋         | 65/1000 [16:24<3:56:31, 15.18s/it]


gloss -4910021.0
wd_loss 230.5958251953125
cls_loss 1.7053824663162231
acc_t 0.130859375
acc_s 0.640625
val_loss : 2.2149919509887694
acc_val : 0.128125

66 : epoch


  7%|▋         | 66/1000 [16:39<3:56:08, 15.17s/it]


gloss -6431249.5
wd_loss 233.7766876220703
cls_loss 1.718130350112915
acc_t 0.12109375
acc_s 0.623046875
val_loss : 2.2142574787139893
acc_val : 0.134375

67 : epoch


  7%|▋         | 67/1000 [16:55<3:55:38, 15.15s/it]


gloss -6097180.5
wd_loss 239.28439331054688
cls_loss 1.6794805526733398
acc_t 0.12109375
acc_s 0.654296875
val_loss : 2.2062800884246827
acc_val : 0.15625

68 : epoch


  7%|▋         | 68/1000 [17:10<3:55:26, 15.16s/it]


gloss -8155503.5
wd_loss 242.97987365722656
cls_loss 1.6488499641418457
acc_t 0.140625
acc_s 0.666015625
val_loss : 2.1957441329956056
acc_val : 0.15625

69 : epoch


  7%|▋         | 69/1000 [17:25<3:55:07, 15.15s/it]


gloss -5865363.0
wd_loss 255.46734619140625
cls_loss 1.6291842460632324
acc_t 0.140625
acc_s 0.671875
val_loss : 2.194205570220947
acc_val : 0.159375

70 : epoch


  7%|▋         | 70/1000 [17:40<3:54:50, 15.15s/it]


gloss -10149298.0
wd_loss 257.9111328125
cls_loss 1.6786028146743774
acc_t 0.134765625
acc_s 0.6328125
val_loss : 2.1963059425354006
acc_val : 0.14375

71 : epoch


  7%|▋         | 71/1000 [17:55<3:54:43, 15.16s/it]


gloss -10219080.0
wd_loss 271.9619445800781
cls_loss 1.6839675903320312
acc_t 0.1328125
acc_s 0.646484375
val_loss : 2.1991507053375243
acc_val : 0.15625

72 : epoch


  7%|▋         | 72/1000 [18:10<3:54:13, 15.14s/it]


gloss -12053357.0
wd_loss 283.6622619628906
cls_loss 1.671107530593872
acc_t 0.14453125
acc_s 0.64453125
val_loss : 2.207352066040039
acc_val : 0.134375

73 : epoch


  7%|▋         | 73/1000 [18:25<3:53:58, 15.14s/it]


gloss -16571074.0
wd_loss 282.1472473144531
cls_loss 1.6633236408233643
acc_t 0.125
acc_s 0.701171875
val_loss : 2.2140777111053467
acc_val : 0.134375

74 : epoch


  7%|▋         | 74/1000 [18:41<3:53:59, 15.16s/it]


gloss -12877227.0
wd_loss 298.00360107421875
cls_loss 1.6426513195037842
acc_t 0.13671875
acc_s 0.693359375
val_loss : 2.21984486579895
acc_val : 0.1375

75 : epoch


  8%|▊         | 75/1000 [18:56<3:53:41, 15.16s/it]


gloss -13268559.0
wd_loss 326.8150329589844
cls_loss 1.598499059677124
acc_t 0.13671875
acc_s 0.701171875
val_loss : 2.232092571258545
acc_val : 0.146875

76 : epoch


  8%|▊         | 76/1000 [19:11<3:53:11, 15.14s/it]


gloss -11616535.0
wd_loss 325.8056640625
cls_loss 1.6114771366119385
acc_t 0.130859375
acc_s 0.65625
val_loss : 2.233233118057251
acc_val : 0.153125

77 : epoch


  8%|▊         | 77/1000 [19:26<3:53:02, 15.15s/it]


gloss -6299918.0
wd_loss 332.07745361328125
cls_loss 1.6227055788040161
acc_t 0.142578125
acc_s 0.65625
val_loss : 2.2310932159423826
acc_val : 0.15

78 : epoch


  8%|▊         | 78/1000 [19:41<3:52:54, 15.16s/it]


gloss -20081092.0
wd_loss 348.0567626953125
cls_loss 1.6108171939849854
acc_t 0.158203125
acc_s 0.65234375
val_loss : 2.2275837898254394
acc_val : 0.140625

79 : epoch


  8%|▊         | 79/1000 [19:56<3:53:04, 15.18s/it]


gloss -23408672.0
wd_loss 357.96435546875
cls_loss 1.631385087966919
acc_t 0.140625
acc_s 0.66796875
val_loss : 2.2334771156311035
acc_val : 0.125

80 : epoch


  8%|▊         | 80/1000 [20:12<3:53:15, 15.21s/it]


gloss -23947384.0
wd_loss 363.5655517578125
cls_loss 1.6522010564804077
acc_t 0.107421875
acc_s 0.69140625
val_loss : 2.239365243911743
acc_val : 0.125

81 : epoch


  8%|▊         | 81/1000 [20:27<3:52:47, 15.20s/it]


gloss -23482720.0
wd_loss 373.9883728027344
cls_loss 1.6074988842010498
acc_t 0.115234375
acc_s 0.705078125
val_loss : 2.249380922317505
acc_val : 0.11875

82 : epoch


  8%|▊         | 82/1000 [20:42<3:52:21, 15.19s/it]


gloss -26500280.0
wd_loss 374.4302062988281
cls_loss 1.5852597951889038
acc_t 0.1328125
acc_s 0.697265625
val_loss : 2.2600011348724367
acc_val : 0.1125

83 : epoch


  8%|▊         | 83/1000 [20:57<3:51:51, 15.17s/it]


gloss -27449256.0
wd_loss 378.0595703125
cls_loss 1.5476576089859009
acc_t 0.138671875
acc_s 0.712890625
val_loss : 2.258493518829346
acc_val : 0.13125

84 : epoch


  8%|▊         | 84/1000 [21:12<3:51:26, 15.16s/it]


gloss -24052476.0
wd_loss 390.9508361816406
cls_loss 1.5478981733322144
acc_t 0.150390625
acc_s 0.677734375
val_loss : 2.2493897914886474
acc_val : 0.13125

85 : epoch


  8%|▊         | 85/1000 [21:27<3:51:18, 15.17s/it]


gloss -21091104.0
wd_loss 400.331787109375
cls_loss 1.5623695850372314
acc_t 0.162109375
acc_s 0.681640625
val_loss : 2.246457052230835
acc_val : 0.15

86 : epoch


  9%|▊         | 86/1000 [21:43<3:50:55, 15.16s/it]


gloss -24141086.0
wd_loss 408.769287109375
cls_loss 1.615649938583374
acc_t 0.1640625
acc_s 0.640625
val_loss : 2.245483922958374
acc_val : 0.134375

87 : epoch


  9%|▊         | 87/1000 [21:58<3:50:30, 15.15s/it]


gloss -29958796.0
wd_loss 425.63427734375
cls_loss 1.5924851894378662
acc_t 0.1640625
acc_s 0.65625
val_loss : 2.244818019866943
acc_val : 0.134375

88 : epoch


  9%|▉         | 88/1000 [22:13<3:50:09, 15.14s/it]


gloss -44450160.0
wd_loss 455.81402587890625
cls_loss 1.5563377141952515
acc_t 0.142578125
acc_s 0.69140625
val_loss : 2.2460869789123534
acc_val : 0.125

89 : epoch


  9%|▉         | 89/1000 [22:28<3:49:52, 15.14s/it]


gloss -49956152.0
wd_loss 457.658935546875
cls_loss 1.5301189422607422
acc_t 0.142578125
acc_s 0.705078125
val_loss : 2.248707580566406
acc_val : 0.153125

90 : epoch


  9%|▉         | 90/1000 [22:43<3:49:58, 15.16s/it]


gloss -23629508.0
wd_loss 464.323974609375
cls_loss 1.5116007328033447
acc_t 0.1640625
acc_s 0.7109375
val_loss : 2.2527072429656982
acc_val : 0.184375

91 : epoch


  9%|▉         | 91/1000 [22:58<3:49:48, 15.17s/it]


gloss -44737948.0
wd_loss 496.2341613769531
cls_loss 1.4884076118469238
acc_t 0.15625
acc_s 0.732421875
val_loss : 2.2468692779541017
acc_val : 0.1625

92 : epoch


  9%|▉         | 92/1000 [23:14<3:49:32, 15.17s/it]


gloss -70774016.0
wd_loss 510.5567626953125
cls_loss 1.506098985671997
acc_t 0.1640625
acc_s 0.693359375
val_loss : 2.255035400390625
acc_val : 0.14375

93 : epoch


  9%|▉         | 93/1000 [23:29<3:49:12, 15.16s/it]


gloss -45730900.0
wd_loss 526.5377197265625
cls_loss 1.5451674461364746
acc_t 0.158203125
acc_s 0.642578125
val_loss : 2.254454517364502
acc_val : 0.15

94 : epoch


  9%|▉         | 94/1000 [23:44<3:48:48, 15.15s/it]


gloss -61944412.0
wd_loss 546.1400756835938
cls_loss 1.529097080230713
acc_t 0.1796875
acc_s 0.63671875
val_loss : 2.241388511657715
acc_val : 0.15

95 : epoch


 10%|▉         | 95/1000 [23:59<3:48:28, 15.15s/it]


gloss -51304668.0
wd_loss 558.46142578125
cls_loss 1.5119590759277344
acc_t 0.173828125
acc_s 0.69921875
val_loss : 2.240837049484253
acc_val : 0.15625

96 : epoch


 10%|▉         | 96/1000 [24:14<3:48:14, 15.15s/it]


gloss -77284528.0
wd_loss 569.763427734375
cls_loss 1.4991538524627686
acc_t 0.15234375
acc_s 0.736328125
val_loss : 2.2508286476135253
acc_val : 0.140625

97 : epoch


 10%|▉         | 97/1000 [24:29<3:48:06, 15.16s/it]


gloss -83248816.0
wd_loss 585.6597900390625
cls_loss 1.4993464946746826
acc_t 0.1484375
acc_s 0.73828125
val_loss : 2.2487314224243162
acc_val : 0.121875

98 : epoch


 10%|▉         | 98/1000 [24:44<3:47:52, 15.16s/it]


gloss -86153584.0
wd_loss 586.5128784179688
cls_loss 1.4913039207458496
acc_t 0.15625
acc_s 0.728515625
val_loss : 2.2550512313842774
acc_val : 0.1375

99 : epoch


 10%|▉         | 99/1000 [25:00<3:47:24, 15.14s/it]


gloss -94597200.0
wd_loss 593.1232299804688
cls_loss 1.5016921758651733
acc_t 0.15234375
acc_s 0.6640625
val_loss : 2.2610371112823486
acc_val : 0.140625

100 : epoch


 10%|█         | 100/1000 [25:15<3:47:06, 15.14s/it]


gloss -118152504.0
wd_loss 613.6976318359375
cls_loss 1.4858351945877075
acc_t 0.130859375
acc_s 0.66015625
val_loss : 2.258763837814331
acc_val : 0.134375

101 : epoch


 10%|█         | 101/1000 [25:30<3:47:01, 15.15s/it]


gloss -111571528.0
wd_loss 632.843505859375
cls_loss 1.4689841270446777
acc_t 0.130859375
acc_s 0.671875
val_loss : 2.252983236312866
acc_val : 0.146875

102 : epoch


 10%|█         | 102/1000 [25:45<3:46:44, 15.15s/it]


gloss -94006016.0
wd_loss 649.6484985351562
cls_loss 1.480655312538147
acc_t 0.13671875
acc_s 0.701171875
val_loss : 2.2543187618255613
acc_val : 0.146875

103 : epoch


 10%|█         | 102/1000 [26:00<3:49:00, 15.30s/it]


gloss -133394432.0
wd_loss 643.0953369140625
cls_loss 1.478195309638977
acc_t 0.125
acc_s 0.70703125





KeyboardInterrupt: 