# generate data

In [3]:
import scipy
from scipy.special import expit, logit
import numpy as np
import math
# rs = np.random.RandomState(1978)

In [309]:
def generate_data_uniform(p, n, min_b, mean_b, meanX, sdX, train_prop = 1.0):
    
    # generate uniform b's
    low = min_b
    high = (2*mean_b) - min_b
    b = np.random.uniform(low, high, size=(p,1))
    # normal X's
    X = torch.from_numpy(np.random.normal(meanX, sdX, size=(n, p))).float()
    
    # logistic Y's
    Y = torch.from_numpy(np.random.binomial(1, expit( np.matmul(X,b)))).float()
    
    # split into training and test
    X_train = X
    Y_train = Y
    X_test = None
    Y_test = None
    
    if train_prop < 1.0:
        xy = torch.cat((Y,X),axis=1)        
        train_cutoff = int(n * train_prop)
        train = xy[:train_cutoff,:]
        Y_train, X_train =  train[:,0], train[:,1:]
        
        test = xy[train_cutoff:,:]
        Y_test, X_test = test[:,0], test[:, 1:]
    
    
    return Y_train, X_train, Y_test, X_test

In [378]:
p = 50

# source params
n_source = 10000
min_b_source = 0.5
mean_b_source = 2.0
meanX_source = 0.0
sdX_source = 1.0
train_prop_source = 0.80

# target params
n_target = 100
min_b_target = 0.5
mean_b_target = 2.0
meanX_target = 0.0
sdX_target = 1.0
train_prop_target = 0.80

In [379]:
Y_source_train, X_source_train, Y_source_test, X_source_test = generate_data_uniform(p = p,
                                                n = n_source,
                                                 min_b = min_b_source,
                                                 mean_b = mean_b_source,
                                                 meanX =meanX_source,
                                                 sdX = sdX_source,
                                                 train_prop = train_prop_source)
                                                         
    
Y_target_train, X_target_train, Y_target_test, X_target_test = generate_data_uniform(p = p,
                                                                n = n_target,
                                                                 min_b = min_b_target,
                                                                 mean_b = mean_b_target,
                                                                 meanX =meanX_target,
                                                                 sdX = sdX_target,
                                                                 train_prop = train_prop_target)

# Pytorch NN and data setup

In [380]:
from torch import nn
from torch.nn import functional as F
import torch
from sklearn.metrics import roc_auc_score

In [381]:
class BinaryClassification(nn.Module):
    def __init__(self, input_shape):
        super(BinaryClassification, self).__init__()
        self.layer_1 = nn.Linear(input_shape, 10) 
        self.layer_out = nn.Linear(10, 1) 

        
    def forward(self, inputs):
        x = torch.relu(self.layer_1(inputs))
        x = torch.relu(self.layer_out(x))
        
        return x

# Train source model

In [382]:
learning_rate = 0.01
epochs = 1400
# Model , Optimizer, Loss
source_model = BinaryClassification(input_shape=X_source_train.shape[1])
optimizer = torch.optim.SGD(source_model.parameters(),lr=learning_rate)
loss_fn = nn.BCEWithLogitsLoss()

In [383]:
losses = []
accur = []
for i in range(epochs):

    #calculate output
    output = source_model(X_source_train)

    #calculate loss
    loss = loss_fn(output,Y_source_train.reshape(-1,1))

    #backprop
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    losses.append(loss)
    print(f"epoch {i}\tloss :{loss}")


epoch 0	loss :0.7060213088989258
epoch 1	loss :0.7058761119842529
epoch 2	loss :0.7057316899299622
epoch 3	loss :0.7055885195732117
epoch 4	loss :0.7054462432861328
epoch 5	loss :0.7053045630455017
epoch 6	loss :0.7051635980606079
epoch 7	loss :0.7050227522850037
epoch 8	loss :0.704882800579071
epoch 9	loss :0.7047437429428101
epoch 10	loss :0.7046060562133789
epoch 11	loss :0.7044693827629089
epoch 12	loss :0.7043334245681763
epoch 13	loss :0.7041980028152466
epoch 14	loss :0.7040631175041199
epoch 15	loss :0.7039293050765991
epoch 16	loss :0.7037969827651978
epoch 17	loss :0.703665018081665
epoch 18	loss :0.7035340666770935
epoch 19	loss :0.7034037709236145
epoch 20	loss :0.7032743096351624
epoch 21	loss :0.7031458616256714
epoch 22	loss :0.7030181884765625
epoch 23	loss :0.7028914093971252
epoch 24	loss :0.7027653455734253
epoch 25	loss :0.702640175819397
epoch 26	loss :0.7025167346000671
epoch 27	loss :0.7023950815200806
epoch 28	loss :0.7022741436958313
epoch 29	loss :0.7021533846

epoch 252	loss :0.6858999133110046
epoch 253	loss :0.6858391761779785
epoch 254	loss :0.6857782602310181
epoch 255	loss :0.6857174038887024
epoch 256	loss :0.6856564283370972
epoch 257	loss :0.6855952143669128
epoch 258	loss :0.6855334043502808
epoch 259	loss :0.6854712963104248
epoch 260	loss :0.685408890247345
epoch 261	loss :0.685346245765686
epoch 262	loss :0.685283362865448
epoch 263	loss :0.6852203011512756
epoch 264	loss :0.6851571202278137
epoch 265	loss :0.6850938200950623
epoch 266	loss :0.6850301027297974
epoch 267	loss :0.6849663853645325
epoch 268	loss :0.684902548789978
epoch 269	loss :0.6848385334014893
epoch 270	loss :0.6847743988037109
epoch 271	loss :0.6847097277641296
epoch 272	loss :0.6846446394920349
epoch 273	loss :0.6845794320106506
epoch 274	loss :0.6845138669013977
epoch 275	loss :0.6844482421875
epoch 276	loss :0.684382438659668
epoch 277	loss :0.6843164563179016
epoch 278	loss :0.6842501759529114
epoch 279	loss :0.6841838359832764
epoch 280	loss :0.6841174364

epoch 506	loss :0.6614860892295837
epoch 507	loss :0.6613467931747437
epoch 508	loss :0.6612074375152588
epoch 509	loss :0.6610677242279053
epoch 510	loss :0.6609269976615906
epoch 511	loss :0.6607856154441833
epoch 512	loss :0.6606439352035522
epoch 513	loss :0.6605017781257629
epoch 514	loss :0.6603594422340393
epoch 515	loss :0.6602168083190918
epoch 516	loss :0.6600739359855652
epoch 517	loss :0.6599308848381042
epoch 518	loss :0.6597875952720642
epoch 519	loss :0.6596438884735107
epoch 520	loss :0.6595001220703125
epoch 521	loss :0.6593560576438904
epoch 522	loss :0.6592115759849548
epoch 523	loss :0.659066915512085
epoch 524	loss :0.6589220762252808
epoch 525	loss :0.6587768793106079
epoch 526	loss :0.6586312651634216
epoch 527	loss :0.6584855318069458
epoch 528	loss :0.6583396196365356
epoch 529	loss :0.6581932306289673
epoch 530	loss :0.6580469012260437
epoch 531	loss :0.6579005718231201
epoch 532	loss :0.6577539443969727
epoch 533	loss :0.6576070785522461
epoch 534	loss :0.657

epoch 764	loss :0.6171359419822693
epoch 765	loss :0.6169339418411255
epoch 766	loss :0.6167318224906921
epoch 767	loss :0.6165294647216797
epoch 768	loss :0.6163268089294434
epoch 769	loss :0.6161238551139832
epoch 770	loss :0.6159208416938782
epoch 771	loss :0.6157175302505493
epoch 772	loss :0.6155139803886414
epoch 773	loss :0.6153098940849304
epoch 774	loss :0.6151055693626404
epoch 775	loss :0.6149011254310608
epoch 776	loss :0.6146965622901917
epoch 777	loss :0.6144918203353882
epoch 778	loss :0.6142869591712952
epoch 779	loss :0.6140819191932678
epoch 780	loss :0.6138768196105957
epoch 781	loss :0.6136713624000549
epoch 782	loss :0.6134658455848694
epoch 783	loss :0.6132601499557495
epoch 784	loss :0.6130542755126953
epoch 785	loss :0.6128483414649963
epoch 786	loss :0.6126423478126526
epoch 787	loss :0.612436056137085
epoch 788	loss :0.6122295260429382
epoch 789	loss :0.6120226979255676
epoch 790	loss :0.6118160486221313
epoch 791	loss :0.6116089224815369
epoch 792	loss :0.611

epoch 1021	loss :0.5610588788986206
epoch 1022	loss :0.5608373284339905
epoch 1023	loss :0.5606157779693604
epoch 1024	loss :0.5603939294815063
epoch 1025	loss :0.5601723194122314
epoch 1026	loss :0.5599509477615356
epoch 1027	loss :0.5597293972969055
epoch 1028	loss :0.5595080852508545
epoch 1029	loss :0.5592868328094482
epoch 1030	loss :0.559065580368042
epoch 1031	loss :0.5588445067405701
epoch 1032	loss :0.5586236715316772
epoch 1033	loss :0.5584030151367188
epoch 1034	loss :0.5581824779510498
epoch 1035	loss :0.5579619407653809
epoch 1036	loss :0.5577414631843567
epoch 1037	loss :0.5575211048126221
epoch 1038	loss :0.557300329208374
epoch 1039	loss :0.5570797920227051
epoch 1040	loss :0.55685955286026
epoch 1041	loss :0.5566391944885254
epoch 1042	loss :0.5564191937446594
epoch 1043	loss :0.5561991930007935
epoch 1044	loss :0.5559791326522827
epoch 1045	loss :0.5557592511177063
epoch 1046	loss :0.5555393695831299
epoch 1047	loss :0.5553199052810669
epoch 1048	loss :0.5551003217697

epoch 1268	loss :0.5104341506958008
epoch 1269	loss :0.5102562308311462
epoch 1270	loss :0.5100787281990051
epoch 1271	loss :0.5099013447761536
epoch 1272	loss :0.509724497795105
epoch 1273	loss :0.509548008441925
epoch 1274	loss :0.5093718767166138
epoch 1275	loss :0.509195864200592
epoch 1276	loss :0.5090201497077942
epoch 1277	loss :0.5088445544242859
epoch 1278	loss :0.5086690783500671
epoch 1279	loss :0.508493959903717
epoch 1280	loss :0.5083192586898804
epoch 1281	loss :0.5081447958946228
epoch 1282	loss :0.5079706311225891
epoch 1283	loss :0.5077967643737793
epoch 1284	loss :0.5076234340667725
epoch 1285	loss :0.5074504017829895
epoch 1286	loss :0.5072776675224304
epoch 1287	loss :0.5071054697036743
epoch 1288	loss :0.5069336295127869
epoch 1289	loss :0.5067620873451233
epoch 1290	loss :0.5065908432006836
epoch 1291	loss :0.506419837474823
epoch 1292	loss :0.5062490701675415
epoch 1293	loss :0.5060784816741943
epoch 1294	loss :0.5059083104133606
epoch 1295	loss :0.50573861598968

In [384]:
source_roc_auc = roc_auc_score(Y_source_test, source_model(X_source_test).detach().numpy())
print(f"source auc:\t{source_roc_auc}")

source auc:	0.9806122448979592


# Train target model (no transfer learning)

In [385]:
learning_rate = 0.01
epochs = 1400
# Model , Optimizer, Loss
target_model = BinaryClassification(input_shape=X_target_train.shape[1])
optimizer = torch.optim.SGD(target_model.parameters(),lr=learning_rate)
loss_fn = nn.BCEWithLogitsLoss()

In [386]:
losses = []
accur = []
for i in range(epochs):

    #calculate output
    output = target_model(X_target_train)

    #calculate loss
    loss = loss_fn(output, Y_target_train.reshape(-1,1))

    #backprop
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    losses.append(loss)
    print(f"epoch {i}\tloss :{loss}")


epoch 0	loss :0.6886869668960571
epoch 1	loss :0.6885937452316284
epoch 2	loss :0.6885005235671997
epoch 3	loss :0.688407301902771
epoch 4	loss :0.6883140206336975
epoch 5	loss :0.6882206201553345
epoch 6	loss :0.6881272792816162
epoch 7	loss :0.688034176826477
epoch 8	loss :0.6879414319992065
epoch 9	loss :0.6878484487533569
epoch 10	loss :0.6877554059028625
epoch 11	loss :0.6876583099365234
epoch 12	loss :0.6875509023666382
epoch 13	loss :0.6874435544013977
epoch 14	loss :0.6873359680175781
epoch 15	loss :0.6872285008430481
epoch 16	loss :0.6871208548545837
epoch 17	loss :0.6870130896568298
epoch 18	loss :0.6869052648544312
epoch 19	loss :0.6867949366569519
epoch 20	loss :0.6866846084594727
epoch 21	loss :0.6865741014480591
epoch 22	loss :0.6864636540412903
epoch 23	loss :0.6863530278205872
epoch 24	loss :0.6862422227859497
epoch 25	loss :0.6861320734024048
epoch 26	loss :0.6860238909721375
epoch 27	loss :0.6859156489372253
epoch 28	loss :0.6858073472976685
epoch 29	loss :0.685699045

epoch 446	loss :0.6192470788955688
epoch 447	loss :0.6190240383148193
epoch 448	loss :0.6187869906425476
epoch 449	loss :0.6185515522956848
epoch 450	loss :0.6183213591575623
epoch 451	loss :0.6180857419967651
epoch 452	loss :0.6178573369979858
epoch 453	loss :0.6176275014877319
epoch 454	loss :0.6174014210700989
epoch 455	loss :0.6171625256538391
epoch 456	loss :0.6169297099113464
epoch 457	loss :0.6166938543319702
epoch 458	loss :0.6164597272872925
epoch 459	loss :0.6162246465682983
epoch 460	loss :0.6159942150115967
epoch 461	loss :0.6157716512680054
epoch 462	loss :0.6155270338058472
epoch 463	loss :0.6152957081794739
epoch 464	loss :0.6150597333908081
epoch 465	loss :0.6148310303688049
epoch 466	loss :0.6145917177200317
epoch 467	loss :0.6143606901168823
epoch 468	loss :0.6141317486763
epoch 469	loss :0.6139023303985596
epoch 470	loss :0.6136651039123535
epoch 471	loss :0.6134350895881653
epoch 472	loss :0.6132105588912964
epoch 473	loss :0.6129711270332336
epoch 474	loss :0.61274

epoch 897	loss :0.4979868531227112
epoch 898	loss :0.4977652430534363
epoch 899	loss :0.4975219666957855
epoch 900	loss :0.49729442596435547
epoch 901	loss :0.4970915913581848
epoch 902	loss :0.49682897329330444
epoch 903	loss :0.49660834670066833
epoch 904	loss :0.4963688850402832
epoch 905	loss :0.49616461992263794
epoch 906	loss :0.4959105849266052
epoch 907	loss :0.49569472670555115
epoch 908	loss :0.49545350670814514
epoch 909	loss :0.4952293932437897
epoch 910	loss :0.4950028955936432
epoch 911	loss :0.4947965741157532
epoch 912	loss :0.49455294013023376
epoch 913	loss :0.4943463206291199
epoch 914	loss :0.4941089153289795
epoch 915	loss :0.49388664960861206
epoch 916	loss :0.493678480386734
epoch 917	loss :0.49346956610679626
epoch 918	loss :0.49323564767837524
epoch 919	loss :0.4930041432380676
epoch 920	loss :0.4927932620048523
epoch 921	loss :0.4925760328769684
epoch 922	loss :0.4923623204231262
epoch 923	loss :0.49212759733200073
epoch 924	loss :0.4918935298919678
epoch 925	

epoch 1132	loss :0.45249027013778687
epoch 1133	loss :0.45234769582748413
epoch 1134	loss :0.4522162079811096
epoch 1135	loss :0.452086865901947
epoch 1136	loss :0.451964795589447
epoch 1137	loss :0.4518211781978607
epoch 1138	loss :0.4516673982143402
epoch 1139	loss :0.4515405595302582
epoch 1140	loss :0.45140379667282104
epoch 1141	loss :0.4512740969657898
epoch 1142	loss :0.4511451721191406
epoch 1143	loss :0.45103707909584045
epoch 1144	loss :0.4508915841579437
epoch 1145	loss :0.4507372975349426
epoch 1146	loss :0.45060762763023376
epoch 1147	loss :0.45050048828125
epoch 1148	loss :0.45035308599472046
epoch 1149	loss :0.4502124786376953
epoch 1150	loss :0.45008140802383423
epoch 1151	loss :0.4499571919441223
epoch 1152	loss :0.4498358368873596
epoch 1153	loss :0.44969063997268677
epoch 1154	loss :0.4495733380317688
epoch 1155	loss :0.449443519115448
epoch 1156	loss :0.4493083953857422
epoch 1157	loss :0.44917869567871094
epoch 1158	loss :0.44905439019203186
epoch 1159	loss :0.4489

epoch 1359	loss :0.42916440963745117
epoch 1360	loss :0.42908841371536255
epoch 1361	loss :0.42901182174682617
epoch 1362	loss :0.4289330840110779
epoch 1363	loss :0.42886871099472046
epoch 1364	loss :0.4287816882133484
epoch 1365	loss :0.4287167489528656
epoch 1366	loss :0.4286419749259949
epoch 1367	loss :0.428567111492157
epoch 1368	loss :0.4284898340702057
epoch 1369	loss :0.42841267585754395
epoch 1370	loss :0.42833957076072693
epoch 1371	loss :0.4282788336277008
epoch 1372	loss :0.4281948208808899
epoch 1373	loss :0.42812976241111755
epoch 1374	loss :0.4280555248260498
epoch 1375	loss :0.4279800355434418
epoch 1376	loss :0.4279085099697113
epoch 1377	loss :0.42784062027931213
epoch 1378	loss :0.427758127450943
epoch 1379	loss :0.42768630385398865
epoch 1380	loss :0.4276251792907715
epoch 1381	loss :0.427560955286026
epoch 1382	loss :0.4274752736091614
epoch 1383	loss :0.42740529775619507
epoch 1384	loss :0.4273461401462555
epoch 1385	loss :0.42726707458496094
epoch 1386	loss :0.4

In [387]:
target_roc_auc = roc_auc_score(Y_target_test, target_model(X_target_test).detach().numpy())
print(f"target auc:\t{target_roc_auc}")

target auc:	0.77


# Target model with Transfer Learning

In [388]:
import copy

In [389]:
learning_rate = 0.01
epochs = 1400
# Model , Optimizer, Loss
target_model_tl = copy.deepcopy(source_model)


In [390]:
# freeze layers by so the weights do not update
for param in target_model_tl.parameters():
    param.requires_grad = False

In [391]:
# reassing last layer with requires_grad=true by default
target_model_tl.layer_out = nn.Linear(target_model_tl.layer_out.in_features, target_model_tl.layer_out.out_features)

In [392]:
optimizer = torch.optim.SGD(target_model_tl.parameters(),lr=learning_rate)
loss_fn = nn.BCEWithLogitsLoss()

In [393]:
losses = []
accur = []
for i in range(epochs):

    #calculate output
    output = target_model_tl(X_target_train)

    #calculate loss
    loss = loss_fn(output, Y_target_train.reshape(-1,1))

    #backprop
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    losses.append(loss)
    print(f"epoch {i}\tloss :{loss}")


epoch 0	loss :0.6808439493179321
epoch 1	loss :0.6806464195251465
epoch 2	loss :0.6804677844047546
epoch 3	loss :0.6802900433540344
epoch 4	loss :0.6801133155822754
epoch 5	loss :0.6799375414848328
epoch 6	loss :0.6797627210617065
epoch 7	loss :0.6795889735221863
epoch 8	loss :0.6794160604476929
epoch 9	loss :0.6792441010475159
epoch 10	loss :0.6790729761123657
epoch 11	loss :0.6789029836654663
epoch 12	loss :0.6787183880805969
epoch 13	loss :0.678497314453125
epoch 14	loss :0.6782774925231934
epoch 15	loss :0.6780589818954468
epoch 16	loss :0.67784184217453
epoch 17	loss :0.6776260137557983
epoch 18	loss :0.6774114370346069
epoch 19	loss :0.6771982312202454
epoch 20	loss :0.6769862771034241
epoch 21	loss :0.6767755746841431
epoch 22	loss :0.6765661835670471
epoch 23	loss :0.6763452887535095
epoch 24	loss :0.6760978102684021
epoch 25	loss :0.6758519411087036
epoch 26	loss :0.6756076812744141
epoch 27	loss :0.6753610968589783
epoch 28	loss :0.6750823259353638
epoch 29	loss :0.6748052835

epoch 259	loss :0.6281794309616089
epoch 260	loss :0.6280785799026489
epoch 261	loss :0.6279783844947815
epoch 262	loss :0.6278786063194275
epoch 263	loss :0.6277794241905212
epoch 264	loss :0.6276808381080627
epoch 265	loss :0.6275827884674072
epoch 266	loss :0.6274852752685547
epoch 267	loss :0.6273882985115051
epoch 268	loss :0.6272919178009033
epoch 269	loss :0.6271926164627075
epoch 270	loss :0.6270812749862671
epoch 271	loss :0.6269705891609192
epoch 272	loss :0.6268606781959534
epoch 273	loss :0.6267513632774353
epoch 274	loss :0.6266427636146545
epoch 275	loss :0.6265348196029663
epoch 276	loss :0.6264274716377258
epoch 277	loss :0.6263208389282227
epoch 278	loss :0.6262148022651672
epoch 279	loss :0.6261094808578491
epoch 280	loss :0.6260046362876892
epoch 281	loss :0.6259006261825562
epoch 282	loss :0.6257971525192261
epoch 283	loss :0.6256942749023438
epoch 284	loss :0.625592052936554
epoch 285	loss :0.6254903674125671
epoch 286	loss :0.6253935098648071
epoch 287	loss :0.625

epoch 520	loss :0.6128738522529602
epoch 521	loss :0.6128369569778442
epoch 522	loss :0.612800121307373
epoch 523	loss :0.6127633452415466
epoch 524	loss :0.6127267479896545
epoch 525	loss :0.6126901507377625
epoch 526	loss :0.6126537322998047
epoch 527	loss :0.6126173734664917
epoch 528	loss :0.6125810146331787
epoch 529	loss :0.6125448346138
epoch 530	loss :0.6125086545944214
epoch 531	loss :0.612472653388977
epoch 532	loss :0.6124366521835327
epoch 533	loss :0.6124008297920227
epoch 534	loss :0.6123650074005127
epoch 535	loss :0.6123292446136475
epoch 536	loss :0.6122936010360718
epoch 537	loss :0.6122580766677856
epoch 538	loss :0.6122225522994995
epoch 539	loss :0.6121872067451477
epoch 540	loss :0.6121519207954407
epoch 541	loss :0.6121166944503784
epoch 542	loss :0.6120814085006714
epoch 543	loss :0.6120463609695435
epoch 544	loss :0.6120113730430603
epoch 545	loss :0.6119764447212219
epoch 546	loss :0.6119415760040283
epoch 547	loss :0.6119067668914795
epoch 548	loss :0.6118720

epoch 774	loss :0.6057284474372864
epoch 775	loss :0.6057072877883911
epoch 776	loss :0.6056859493255615
epoch 777	loss :0.605664849281311
epoch 778	loss :0.605643630027771
epoch 779	loss :0.6056225895881653
epoch 780	loss :0.6056015491485596
epoch 781	loss :0.6055805087089539
epoch 782	loss :0.6055595278739929
epoch 783	loss :0.6055386066436768
epoch 784	loss :0.6055176854133606
epoch 785	loss :0.6054968237876892
epoch 786	loss :0.6054760217666626
epoch 787	loss :0.605455219745636
epoch 788	loss :0.6054344773292542
epoch 789	loss :0.6054137945175171
epoch 790	loss :0.6053930521011353
epoch 791	loss :0.6053724884986877
epoch 792	loss :0.6053518056869507
epoch 793	loss :0.605331301689148
epoch 794	loss :0.6053107976913452
epoch 795	loss :0.6052902936935425
epoch 796	loss :0.6052698493003845
epoch 797	loss :0.6052494645118713
epoch 798	loss :0.6052290201187134
epoch 799	loss :0.6052087545394897
epoch 800	loss :0.6051883697509766
epoch 801	loss :0.6051681637763977
epoch 802	loss :0.605147

epoch 1044	loss :0.600805938243866
epoch 1045	loss :0.6007910966873169
epoch 1046	loss :0.6007761359214783
epoch 1047	loss :0.6007611751556396
epoch 1048	loss :0.6007461547851562
epoch 1049	loss :0.6007314920425415
epoch 1050	loss :0.6007165908813477
epoch 1051	loss :0.600701630115509
epoch 1052	loss :0.6006867289543152
epoch 1053	loss :0.6006721258163452
epoch 1054	loss :0.6006573438644409
epoch 1055	loss :0.6006425023078918
epoch 1056	loss :0.6006277203559875
epoch 1057	loss :0.600612998008728
epoch 1058	loss :0.6005984544754028
epoch 1059	loss :0.6005836725234985
epoch 1060	loss :0.600568950176239
epoch 1061	loss :0.600554347038269
epoch 1062	loss :0.6005398631095886
epoch 1063	loss :0.6005252599716187
epoch 1064	loss :0.6005106568336487
epoch 1065	loss :0.6004959940910339
epoch 1066	loss :0.6004815697669983
epoch 1067	loss :0.6004670858383179
epoch 1068	loss :0.6004525423049927
epoch 1069	loss :0.6004379987716675
epoch 1070	loss :0.6004235148429871
epoch 1071	loss :0.60040920972824

epoch 1311	loss :0.597609281539917
epoch 1312	loss :0.597599446773529
epoch 1313	loss :0.5975895524024963
epoch 1314	loss :0.5975795984268188
epoch 1315	loss :0.5975697636604309
epoch 1316	loss :0.5975598096847534
epoch 1317	loss :0.5975500345230103
epoch 1318	loss :0.5975401401519775
epoch 1319	loss :0.5975303649902344
epoch 1320	loss :0.5975204706192017
epoch 1321	loss :0.5975106954574585
epoch 1322	loss :0.5975008606910706
epoch 1323	loss :0.5974910259246826
epoch 1324	loss :0.5974812507629395
epoch 1325	loss :0.5974715352058411
epoch 1326	loss :0.5974617004394531
epoch 1327	loss :0.5974519848823547
epoch 1328	loss :0.5974421501159668
epoch 1329	loss :0.5974324345588684
epoch 1330	loss :0.59742271900177
epoch 1331	loss :0.5974129438400269
epoch 1332	loss :0.5974032878875732
epoch 1333	loss :0.5973935127258301
epoch 1334	loss :0.5973838567733765
epoch 1335	loss :0.5973741412162781
epoch 1336	loss :0.5973644852638245
epoch 1337	loss :0.5973547697067261
epoch 1338	loss :0.5973451137542

In [394]:
target_tl_roc_auc = roc_auc_score(Y_target_test, target_model_tl(X_target_test).detach().numpy())
print(f"target TL auc:\t{target_tl_roc_auc}")

target TL auc:	0.94
