# generate data

In [1]:
import scipy
import numpy as np
import math
# rs = np.random.RandomState(1978)

In [2]:
def sigmoid(x):
  return 1 / (1 + math.exp(-x))

In [3]:
theta_S = [3., 4., 5., 6.]
theta_T = [3., 4., 5., 3.]
p = len(theta_T)

In [4]:
n_source = 10000
n_target = 100

In [5]:
X_S = np.random.normal(0, 1, size = (n_source, p))
X_T = np.random.normal(0, 1, size = (n_target, p))

In [6]:
g_S = X_S @ theta_S
g_T = X_T @ theta_T

In [7]:
y_S = [np.random.binomial(1, sigmoid(g_Si), 1)[0] for g_Si in g_S]
y_T = [np.random.binomial(1, sigmoid(g_Ti), 1)[0] for g_Ti in g_T]

# Pytorch NN and data setup

In [8]:
from torch import nn
from torch.nn import functional as F
import torch
from sklearn.metrics import roc_auc_score

In [9]:
class BinaryClassification(nn.Module):
    def __init__(self, input_shape):
        super(BinaryClassification, self).__init__()
        self.layer_1 = nn.Linear(input_shape, 10) 
        self.layer_out = nn.Linear(10, 1) 

        
    def forward(self, inputs):
        x = torch.relu(self.layer_1(inputs))
        x = torch.relu(self.layer_out(x))
        
        return x

In [10]:
X_S_train = torch.tensor(X_S[0:8000, :].astype(np.float32), dtype=torch.float32)
X_S_test = torch.tensor(X_S[8001:, :].astype(np.float32), dtype=torch.float32)
y_S_train = torch.tensor(y_S[0:8000], dtype=torch.float32)
y_S_test = torch.tensor(y_S[8001:], dtype=torch.float32)


X_T_train = torch.tensor(X_T[0:80, :].astype(np.float32), dtype=torch.float32)
X_T_test = torch.tensor(X_T[80:, :].astype(np.float32), dtype=torch.float32)
y_T_train = torch.tensor(y_T[0:80], dtype=torch.float32)
y_T_test = torch.tensor(y_T[80:], dtype=torch.float32)

# Train source model

In [11]:
learning_rate = 0.01
epochs = 1400
# Model , Optimizer, Loss
source_model = BinaryClassification(input_shape=X_S_train.shape[1])
optimizer = torch.optim.SGD(source_model.parameters(),lr=learning_rate)
loss_fn = nn.BCEWithLogitsLoss()

In [12]:
losses = []
accur = []
for i in range(epochs):

    #calculate output
    output = source_model(X_S_train)

    #calculate loss
    loss = loss_fn(output,y_S_train.reshape(-1,1))

    #backprop
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    losses.append(loss)
    print(f"epoch {i}\tloss :{loss}")


epoch 0	loss :0.6914188265800476
epoch 1	loss :0.6913228631019592
epoch 2	loss :0.6912240386009216
epoch 3	loss :0.6911231875419617
epoch 4	loss :0.6910192370414734
epoch 5	loss :0.6909112930297852
epoch 6	loss :0.6908011436462402
epoch 7	loss :0.6906874775886536
epoch 8	loss :0.690571129322052
epoch 9	loss :0.6904516816139221
epoch 10	loss :0.6903294920921326
epoch 11	loss :0.6902040839195251
epoch 12	loss :0.6900757551193237
epoch 13	loss :0.6899462938308716
epoch 14	loss :0.6898165941238403
epoch 15	loss :0.6896845698356628
epoch 16	loss :0.6895511150360107
epoch 17	loss :0.6894152164459229
epoch 18	loss :0.6892784833908081
epoch 19	loss :0.6891405582427979
epoch 20	loss :0.689000129699707
epoch 21	loss :0.6888564229011536
epoch 22	loss :0.6887089610099792
epoch 23	loss :0.688558042049408
epoch 24	loss :0.6884051561355591
epoch 25	loss :0.6882496476173401
epoch 26	loss :0.6880915760993958
epoch 27	loss :0.6879312992095947
epoch 28	loss :0.6877679228782654
epoch 29	loss :0.6876007318

epoch 453	loss :0.5888797640800476
epoch 454	loss :0.5886512398719788
epoch 455	loss :0.588422954082489
epoch 456	loss :0.5881947875022888
epoch 457	loss :0.5879666805267334
epoch 458	loss :0.5877385139465332
epoch 459	loss :0.5875105261802673
epoch 460	loss :0.5872825980186462
epoch 461	loss :0.5870546698570251
epoch 462	loss :0.5868269801139832
epoch 463	loss :0.5865991115570068
epoch 464	loss :0.5863710641860962
epoch 465	loss :0.5861430168151855
epoch 466	loss :0.5859149694442749
epoch 467	loss :0.5856867432594299
epoch 468	loss :0.5854583978652954
epoch 469	loss :0.5852302312850952
epoch 470	loss :0.585002064704895
epoch 471	loss :0.5847738981246948
epoch 472	loss :0.5845457315444946
epoch 473	loss :0.5843175649642944
epoch 474	loss :0.5840896964073181
epoch 475	loss :0.5838618874549866
epoch 476	loss :0.5836338400840759
epoch 477	loss :0.5834062099456787
epoch 478	loss :0.5831785798072815
epoch 479	loss :0.5829510688781738
epoch 480	loss :0.5827236175537109
epoch 481	loss :0.5824

epoch 756	loss :0.5242977738380432
epoch 757	loss :0.5241155028343201
epoch 758	loss :0.5239334106445312
epoch 759	loss :0.5237517356872559
epoch 760	loss :0.5235704779624939
epoch 761	loss :0.5233895778656006
epoch 762	loss :0.5232087969779968
epoch 763	loss :0.5230282545089722
epoch 764	loss :0.5228481292724609
epoch 765	loss :0.5226680636405945
epoch 766	loss :0.5224884152412415
epoch 767	loss :0.5223088264465332
epoch 768	loss :0.5221295356750488
epoch 769	loss :0.5219504833221436
epoch 770	loss :0.5217716693878174
epoch 771	loss :0.5215930342674255
epoch 772	loss :0.5214146375656128
epoch 773	loss :0.5212363004684448
epoch 774	loss :0.5210584998130798
epoch 775	loss :0.5208811163902283
epoch 776	loss :0.5207040905952454
epoch 777	loss :0.5205275416374207
epoch 778	loss :0.5203512907028198
epoch 779	loss :0.5201750993728638
epoch 780	loss :0.5199992060661316
epoch 781	loss :0.5198236703872681
epoch 782	loss :0.519648551940918
epoch 783	loss :0.5194737911224365
epoch 784	loss :0.519

epoch 1210	loss :0.47169041633605957
epoch 1211	loss :0.4716266691684723
epoch 1212	loss :0.4715629816055298
epoch 1213	loss :0.4714994430541992
epoch 1214	loss :0.4714360237121582
epoch 1215	loss :0.47137269377708435
epoch 1216	loss :0.47130951285362244
epoch 1217	loss :0.4712463319301605
epoch 1218	loss :0.47118330001831055
epoch 1219	loss :0.4711204469203949
epoch 1220	loss :0.4710577428340912
epoch 1221	loss :0.4709951877593994
epoch 1222	loss :0.4709327816963196
epoch 1223	loss :0.4708704352378845
epoch 1224	loss :0.470808207988739
epoch 1225	loss :0.470746248960495
epoch 1226	loss :0.47068437933921814
epoch 1227	loss :0.47062259912490845
epoch 1228	loss :0.4705609083175659
epoch 1229	loss :0.4704993665218353
epoch 1230	loss :0.4704378843307495
epoch 1231	loss :0.4703766107559204
epoch 1232	loss :0.47031545639038086
epoch 1233	loss :0.47025439143180847
epoch 1234	loss :0.47019341588020325
epoch 1235	loss :0.47013258934020996
epoch 1236	loss :0.4700719118118286
epoch 1237	loss :0.4

In [13]:
source_roc_auc = roc_auc_score(y_S_test, source_model(X_S_test).detach().numpy())
print(f"source auc:\t{source_roc_auc}")

source auc:	0.9915998426726388


# Train target model (no transfer learning)

In [14]:
learning_rate = 0.01
epochs = 1400
# Model , Optimizer, Loss
target_model = BinaryClassification(input_shape=X_T_train.shape[1])
optimizer = torch.optim.SGD(target_model.parameters(),lr=learning_rate)
loss_fn = nn.BCEWithLogitsLoss()

In [15]:
losses = []
accur = []
for i in range(epochs):

    #calculate output
    output = target_model(X_T_train)

    #calculate loss
    loss = loss_fn(output,y_T_train.reshape(-1,1))

    #backprop
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    losses.append(loss)
    print(f"epoch {i}\tloss :{loss}")


epoch 0	loss :0.7063497304916382
epoch 1	loss :0.7061512470245361
epoch 2	loss :0.7059536576271057
epoch 3	loss :0.7057567834854126
epoch 4	loss :0.7055525779724121
epoch 5	loss :0.7053369283676147
epoch 6	loss :0.7051223516464233
epoch 7	loss :0.7049084901809692
epoch 8	loss :0.7046957015991211
epoch 9	loss :0.7044870257377625
epoch 10	loss :0.7042968273162842
epoch 11	loss :0.7041073441505432
epoch 12	loss :0.7039186954498291
epoch 13	loss :0.7037307024002075
epoch 14	loss :0.7035435438156128
epoch 15	loss :0.7033571004867554
epoch 16	loss :0.7031713724136353
epoch 17	loss :0.702986478805542
epoch 18	loss :0.7028021812438965
epoch 19	loss :0.7026186585426331
epoch 20	loss :0.7024358510971069
epoch 21	loss :0.7022538185119629
epoch 22	loss :0.7020663022994995
epoch 23	loss :0.7018777132034302
epoch 24	loss :0.7017021775245667
epoch 25	loss :0.7015269994735718
epoch 26	loss :0.7013585567474365
epoch 27	loss :0.7012028694152832
epoch 28	loss :0.7010480165481567
epoch 29	loss :0.70089393

epoch 434	loss :0.6870590448379517
epoch 435	loss :0.6870492696762085
epoch 436	loss :0.6870394349098206
epoch 437	loss :0.6870299577713013
epoch 438	loss :0.6870200037956238
epoch 439	loss :0.6870101690292358
epoch 440	loss :0.6870003938674927
epoch 441	loss :0.6869905591011047
epoch 442	loss :0.6869806051254272
epoch 443	loss :0.6869708299636841
epoch 444	loss :0.68696129322052
epoch 445	loss :0.6869513988494873
epoch 446	loss :0.6869415044784546
epoch 447	loss :0.6869316697120667
epoch 448	loss :0.6869217753410339
epoch 449	loss :0.6869118809700012
epoch 450	loss :0.6869021058082581
epoch 451	loss :0.6868923306465149
epoch 452	loss :0.686882495880127
epoch 453	loss :0.6868726015090942
epoch 454	loss :0.686862587928772
epoch 455	loss :0.686852753162384
epoch 456	loss :0.6868428587913513
epoch 457	loss :0.6868330836296082
epoch 458	loss :0.6868232488632202
epoch 459	loss :0.6868134140968323
epoch 460	loss :0.6868034601211548
epoch 461	loss :0.6867934465408325
epoch 462	loss :0.6867834

epoch 864	loss :0.6773687601089478
epoch 865	loss :0.6773241758346558
epoch 866	loss :0.6772797107696533
epoch 867	loss :0.6772352457046509
epoch 868	loss :0.6771907806396484
epoch 869	loss :0.677146315574646
epoch 870	loss :0.677101731300354
epoch 871	loss :0.6770573854446411
epoch 872	loss :0.6770128011703491
epoch 873	loss :0.6769683957099915
epoch 874	loss :0.676923930644989
epoch 875	loss :0.6768794655799866
epoch 876	loss :0.6768350601196289
epoch 877	loss :0.6767905950546265
epoch 878	loss :0.6767460703849792
epoch 879	loss :0.6767016649246216
epoch 880	loss :0.6766573190689087
epoch 881	loss :0.6766129732131958
epoch 882	loss :0.6765683889389038
epoch 883	loss :0.6765240430831909
epoch 884	loss :0.6764796376228333
epoch 885	loss :0.6764352321624756
epoch 886	loss :0.6763908267021179
epoch 887	loss :0.6763464212417603
epoch 888	loss :0.6763020753860474
epoch 889	loss :0.6762576103210449
epoch 890	loss :0.676213264465332
epoch 891	loss :0.6761688590049744
epoch 892	loss :0.676124

epoch 1313	loss :0.5922451615333557
epoch 1314	loss :0.5920246839523315
epoch 1315	loss :0.5918058156967163
epoch 1316	loss :0.5915879011154175
epoch 1317	loss :0.5913702249526978
epoch 1318	loss :0.591152548789978
epoch 1319	loss :0.5909349918365479
epoch 1320	loss :0.590717613697052
epoch 1321	loss :0.5905002355575562
epoch 1322	loss :0.5902830362319946
epoch 1323	loss :0.5900660157203674
epoch 1324	loss :0.5898489952087402
epoch 1325	loss :0.5896345376968384
epoch 1326	loss :0.5894203782081604
epoch 1327	loss :0.5892060995101929
epoch 1328	loss :0.5889919996261597
epoch 1329	loss :0.5887779593467712
epoch 1330	loss :0.5885640382766724
epoch 1331	loss :0.588350236415863
epoch 1332	loss :0.5881367325782776
epoch 1333	loss :0.5879233479499817
epoch 1334	loss :0.5877098441123962
epoch 1335	loss :0.5874965190887451
epoch 1336	loss :0.5872833132743835
epoch 1337	loss :0.5870707035064697
epoch 1338	loss :0.5868576765060425
epoch 1339	loss :0.5866446495056152
epoch 1340	loss :0.586432337760

In [16]:
target_roc_auc = roc_auc_score(y_T_test, target_model(X_T_test).detach().numpy())
print(f"target auc:\t{target_roc_auc}")

target auc:	0.925


# Target model with Transfer Learning

In [17]:
import copy

In [18]:
learning_rate = 0.01
epochs = 1400
# Model , Optimizer, Loss
target_model_tl = copy.deepcopy(source_model)


In [19]:
# freeze layers by so the weights do not update
for param in target_model_tl.parameters():
    param.requires_grad = False

In [20]:
# reassing last layer with requires_grad=true by default
target_model_tl.layer_out = nn.Linear(target_model_tl.layer_out.in_features, target_model_tl.layer_out.out_features)

In [21]:
optimizer = torch.optim.SGD(target_model_tl.parameters(),lr=learning_rate)
loss_fn = nn.BCEWithLogitsLoss()

In [22]:
losses = []
accur = []
for i in range(epochs):

    #calculate output
    output = target_model_tl(X_T_train)

    #calculate loss
    loss = loss_fn(output,y_T_train.reshape(-1,1))

    #backprop
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    losses.append(loss)
    print(f"epoch {i}\tloss :{loss}")


epoch 0	loss :0.6630159616470337
epoch 1	loss :0.6622641086578369
epoch 2	loss :0.6615195274353027
epoch 3	loss :0.6607820391654968
epoch 4	loss :0.660051703453064
epoch 5	loss :0.6593283414840698
epoch 6	loss :0.6586119532585144
epoch 7	loss :0.6579023599624634
epoch 8	loss :0.657199501991272
epoch 9	loss :0.6565061807632446
epoch 10	loss :0.6558586955070496
epoch 11	loss :0.6552068591117859
epoch 12	loss :0.654466986656189
epoch 13	loss :0.6537343263626099
epoch 14	loss :0.6530085802078247
epoch 15	loss :0.652267336845398
epoch 16	loss :0.6515241861343384
epoch 17	loss :0.6507692337036133
epoch 18	loss :0.6500217914581299
epoch 19	loss :0.6492816805839539
epoch 20	loss :0.6485317945480347
epoch 21	loss :0.6477193832397461
epoch 22	loss :0.646915078163147
epoch 23	loss :0.6461189389228821
epoch 24	loss :0.6453307271003723
epoch 25	loss :0.6445504426956177
epoch 26	loss :0.6437792778015137
epoch 27	loss :0.6430329084396362
epoch 28	loss :0.6422614455223083
epoch 29	loss :0.641502499580

epoch 510	loss :0.5500658750534058
epoch 511	loss :0.5500124096870422
epoch 512	loss :0.5499590635299683
epoch 513	loss :0.5499058365821838
epoch 514	loss :0.5498526692390442
epoch 515	loss :0.5497996807098389
epoch 516	loss :0.5497468113899231
epoch 517	loss :0.5496939420700073
epoch 518	loss :0.5496412515640259
epoch 519	loss :0.5495887398719788
epoch 520	loss :0.5495362281799316
epoch 521	loss :0.5494838356971741
epoch 522	loss :0.5494316220283508
epoch 523	loss :0.5493794679641724
epoch 524	loss :0.5493273138999939
epoch 525	loss :0.5492753982543945
epoch 526	loss :0.5492235422134399
epoch 527	loss :0.5491718649864197
epoch 528	loss :0.5491201281547546
epoch 529	loss :0.549068808555603
epoch 530	loss :0.5490171909332275
epoch 531	loss :0.5489659309387207
epoch 532	loss :0.5489146113395691
epoch 533	loss :0.5488635897636414
epoch 534	loss :0.5488124489784241
epoch 535	loss :0.5487616658210754
epoch 536	loss :0.5487107038497925
epoch 537	loss :0.5486599802970886
epoch 538	loss :0.548

epoch 1034	loss :0.5319386720657349
epoch 1035	loss :0.5319135189056396
epoch 1036	loss :0.531888484954834
epoch 1037	loss :0.5318633913993835
epoch 1038	loss :0.5318384170532227
epoch 1039	loss :0.5318132638931274
epoch 1040	loss :0.5317882299423218
epoch 1041	loss :0.5317632555961609
epoch 1042	loss :0.5317384004592896
epoch 1043	loss :0.5317133069038391
epoch 1044	loss :0.5316883325576782
epoch 1045	loss :0.5316634774208069
epoch 1046	loss :0.531638503074646
epoch 1047	loss :0.5316135883331299
epoch 1048	loss :0.5315886735916138
epoch 1049	loss :0.5315637588500977
epoch 1050	loss :0.5315390229225159
epoch 1051	loss :0.5315141677856445
epoch 1052	loss :0.5314893126487732
epoch 1053	loss :0.5314644575119019
epoch 1054	loss :0.5314396619796753
epoch 1055	loss :0.5314148664474487
epoch 1056	loss :0.5313901305198669
epoch 1057	loss :0.5313653945922852
epoch 1058	loss :0.5313406586647034
epoch 1059	loss :0.5313159227371216
epoch 1060	loss :0.5312911868095398
epoch 1061	loss :0.53126657009

In [23]:
target_tl_roc_auc = roc_auc_score(y_T_test, target_model_tl(X_T_test).detach().numpy())
print(f"target TL auc:\t{target_tl_roc_auc}")

target TL auc:	1.0
