In [1]:
# Make the project's `src` directory the working directory.
# The path is relative, so the result depends on the directory the kernel was
# started in.
# NOTE(review): presumably required so that `datasets.load_d4ls` (imported in
# the next cell) can be resolved — confirm.
import os
os.chdir('../src')

In [52]:


import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from datasets.load_d4ls import load_full_anndata
from sklearn.metrics import pairwise_distances
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.metrics import accuracy_score, classification_report
from sklearn.neural_network import MLPClassifier
import time
from sklearn.model_selection import KFold

pd.set_option('display.max_columns', None)

In [5]:
def get_edge_index(pos, sample_ids, distance_thres):
    """Build directed neighbour edges between observations of the same sample.

    For every unique value in ``sample_ids`` the pairwise distances between
    the rows of ``pos`` that belong to it are computed; each ordered pair
    ``(a, b)`` with ``a != b`` and distance strictly below ``distance_thres``
    becomes one edge. Observations of different samples are never connected.

    Parameters
    ----------
    pos : array indexed as ``pos[rows, :]``
        One coordinate row per observation.
    sample_ids : array-like, one entry per row of ``pos``
        Sample identifier of each observation.
    distance_thres : number
        Exclusive upper bound on the distance of connected observations.

    Returns
    -------
    list of ``[source, target]`` pairs
        Row positions into ``pos``; each undirected neighbourhood appears in
        both directions.
    """
    edges = []
    for current_id in np.unique(sample_ids):
        # Row positions (into `pos`) of the observations of this sample.
        member_idx = np.where(sample_ids == current_id)[0]
        close = pairwise_distances(pos[member_idx, :]) < distance_thres
        # An observation is never its own neighbour.
        np.fill_diagonal(close, 0)
        rows, cols = np.nonzero(close)
        # Translate within-sample positions back to positions in `pos`.
        edges.extend([member_idx[r], member_idx[c]] for r, c in zip(rows, cols))
    return edges

In [6]:
def get_train_test_masks(train_anndata, test_count=0, random_state=None):
    """Split observations into train/test masks by holding out whole samples.

    ``test_count`` distinct values of ``train_anndata.obs["sample_id"]`` are
    drawn without replacement. Every observation belonging to a drawn sample
    is marked as test; all remaining observations are marked as train.

    Parameters
    ----------
    train_anndata
        Object whose ``obs["sample_id"]`` is a pandas Series (``.isin`` is
        called on it).
    test_count : int, default 0
        Number of unique sample ids to hold out. With the default the test
        mask is all ``False``.
    random_state : None, int, numpy.random.RandomState or numpy.random.Generator
        Source of randomness for the draw. ``None`` (default) uses the global
        NumPy random state, which is the previous behaviour; pass an int to
        get the same split on every call.

    Returns
    -------
    (train_mask, test_mask)
        Two boolean Series aligned with ``obs``; each is the negation of the
        other.

    Raises
    ------
    ValueError
        Raised by the underlying ``choice`` call when ``test_count`` exceeds
        the number of unique sample ids.
    """
    sample_ids = train_anndata.obs["sample_id"]
    sample_ids_unique = np.unique(sample_ids)

    if random_state is None:
        # Keep honouring np.random.seed(...) set by the caller.
        rng = np.random
    elif isinstance(random_state, (np.random.RandomState, np.random.Generator)):
        rng = random_state
    else:
        rng = np.random.RandomState(random_state)

    held_out_idx = rng.choice(np.arange(len(sample_ids_unique)), test_count, replace=False)
    test_unique_sample_ids = sample_ids_unique[held_out_idx]

    test_mask = sample_ids.isin(test_unique_sample_ids)
    train_mask = ~test_mask

    return train_mask, test_mask

In [7]:
def prepare_data(train_anndata, make_graph=False, test_samples=10, distance_thres=10):
    """Turn the annotated dataset into train/test features, integer labels and edges.

    The split is made by ``get_train_test_masks`` (whole samples are held
    out). Features come from ``train_anndata.layers['exprs']``; labels come
    from ``train_anndata.obs["cell_labels"]`` and are encoded as the position
    of each label in the sorted list of distinct labels.

    Parameters
    ----------
    train_anndata
        Object exposing ``layers['exprs']`` and an ``obs`` table with the
        columns ``sample_id``, ``cell_labels`` and, when ``make_graph`` is
        true, ``Pos_X`` / ``Pos_Y``.
    make_graph : bool, default False
        When true, neighbour edges are built separately for the train and
        the test observations with ``get_edge_index``.
    test_samples : int, default 10
        Number of samples handed to ``get_train_test_masks`` as hold-out count.
    distance_thres : number, default 10
        Distance threshold forwarded to ``get_edge_index``. The default is the
        value that used to be hard-coded.

    Returns
    -------
    tuple
        ``(X_train, Y_train, edges_train, X_test, Y_test, edges_test,
        inverse_dict)``. The edge entries are ``None`` unless ``make_graph``
        is true. ``inverse_dict`` maps each integer code back to its label.
    """
    train_mask, test_mask = get_train_test_masks(train_anndata, test_samples)
    # Plain boolean arrays select rows by position, independent of the labels
    # of the obs index and of the container type of the expression layer.
    train_sel = np.asarray(train_mask, dtype=bool)
    test_sel = np.asarray(test_mask, dtype=bool)

    X = train_anndata.layers['exprs']
    X_train = X[train_sel]
    X_test = X[test_sel]

    if make_graph:
        pos = train_anndata.obs[["Pos_X", "Pos_Y"]].values
        sample_ids = np.asarray(train_anndata.obs["sample_id"])
        # Edges are built per subset, so their indices refer to rows of
        # X_train / X_test respectively, not to rows of the full dataset.
        edges_train = get_edge_index(pos[train_sel], sample_ids[train_sel], distance_thres)
        edges_test = get_edge_index(pos[test_sel], sample_ids[test_sel], distance_thres)
    else:
        edges_train = None
        edges_test = None

    labels = np.asarray(train_anndata.obs["cell_labels"])
    cell_types = np.unique(labels).tolist()
    # Map label text to an integer code; inverse_dict maps the code back.
    cell_type_dict = {cell_type: code for code, cell_type in enumerate(cell_types)}
    inverse_dict = dict(enumerate(cell_types))

    Y_all = np.array([cell_type_dict[label] for label in labels])
    Y_train = Y_all[train_sel]
    Y_test = Y_all[test_sel]

    return X_train, Y_train, edges_train, X_test, Y_test, edges_test, inverse_dict



In [8]:
# Load the dataset and build one train/test split using prepare_data's
# defaults (no graph edges, 10 held-out samples).
# NOTE(review): the held-out samples are drawn from NumPy's global random
# state and no seed is set in the visible cells, so the split (and every
# score below) can change between runs.
train_anndata = load_full_anndata()

X_train, Y_train, edges_train, X_test, Y_test, edges_test, inverse_dict = prepare_data(train_anndata)

In [90]:
# Second call to load_full_anndata(); it rebinds `train_anndata`, which the
# cell that runs prepare_data has already loaded.
train_anndata = load_full_anndata()

In [93]:
train_anndata.obs

Unnamed: 0,image,sample_id,ObjectNumber,Pos_X,Pos_Y,area,major_axis_length,minor_axis_length,eccentricity,width_px,height_px,acquisition_id,SlideId,Study,Box.Description,Position,SampleId,Indication,BatchId,SubBatchId,ROI,ROIonSlide,includeImage,flag_no_cells,flag_no_ROI,flag_total_area,flag_percent_covered,small_cell,celltypes,flag_tumor,PD1_pos,Ki67_pos,cleavedPARP_pos,GrzB_pos,tumor_patches,distToCells,CD20_patches,Batch,cell_labels
IMMUcan_batch20191023_10032145-THOR-VAR-TIS-01-IMC-01_002.tiff_1,IMMUcan_batch20191023_10032145-THOR-VAR-TIS-01...,IMMUcan_batch20191023_10032145-THOR-VAR-TIS-01...,1.0,300.846154,0.692308,13.0,6.094800,2.780135,0.889904,600.0,600.0,2.0,10032145-THOR-VAR-TIS-01-IMC-01,180305_THOR,Slide_IMC-TIS-01,1,10032145-THOR-VAR-TIS-01-PB,THOR,Batch20191023,Batch20191023_01,2,2,1,0,0,0,0,0,MacCD163,0,0,0,0,0,1,8.773580,,Batch20191023,MacCD163
IMMUcan_batch20191023_10032145-THOR-VAR-TIS-01-IMC-01_002.tiff_3,IMMUcan_batch20191023_10032145-THOR-VAR-TIS-01...,IMMUcan_batch20191023_10032145-THOR-VAR-TIS-01...,3.0,26.982143,0.928571,56.0,21.520654,3.368407,0.987675,600.0,600.0,2.0,10032145-THOR-VAR-TIS-01-IMC-01,180305_THOR,Slide_IMC-TIS-01,1,10032145-THOR-VAR-TIS-01-PB,THOR,Batch20191023,Batch20191023_01,2,2,1,0,0,0,0,0,Mural,0,0,0,0,0,0,72.247393,,Batch20191023,Mural
IMMUcan_batch20191023_10032145-THOR-VAR-TIS-01-IMC-01_002.tiff_5,IMMUcan_batch20191023_10032145-THOR-VAR-TIS-01...,IMMUcan_batch20191023_10032145-THOR-VAR-TIS-01...,5.0,309.083333,0.750000,12.0,5.294329,2.862220,0.841267,600.0,600.0,2.0,10032145-THOR-VAR-TIS-01-IMC-01,180305_THOR,Slide_IMC-TIS-01,1,10032145-THOR-VAR-TIS-01-PB,THOR,Batch20191023,Batch20191023_01,2,2,1,0,0,0,0,0,DC,0,0,0,0,0,1,16.982199,,Batch20191023,DC
IMMUcan_batch20191023_10032145-THOR-VAR-TIS-01-IMC-01_002.tiff_7,IMMUcan_batch20191023_10032145-THOR-VAR-TIS-01...,IMMUcan_batch20191023_10032145-THOR-VAR-TIS-01...,7.0,431.916667,0.750000,12.0,5.294329,2.862220,0.841267,600.0,600.0,2.0,10032145-THOR-VAR-TIS-01-IMC-01,180305_THOR,Slide_IMC-TIS-01,1,10032145-THOR-VAR-TIS-01-PB,THOR,Batch20191023,Batch20191023_01,2,2,1,0,0,0,0,0,Tumor,0,0,0,0,0,1,-8.314676,,Batch20191023,Tumor
IMMUcan_batch20191023_10032145-THOR-VAR-TIS-01-IMC-01_002.tiff_8,IMMUcan_batch20191023_10032145-THOR-VAR-TIS-01...,IMMUcan_batch20191023_10032145-THOR-VAR-TIS-01...,8.0,116.931034,1.206897,29.0,9.216670,4.112503,0.894932,600.0,600.0,2.0,10032145-THOR-VAR-TIS-01-IMC-01,180305_THOR,Slide_IMC-TIS-01,1,10032145-THOR-VAR-TIS-01-PB,THOR,Batch20191023,Batch20191023_01,2,2,1,0,0,0,0,0,Tumor,0,0,0,0,0,1,-15.358007,,Batch20191023,Tumor
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
IMMUcan_Batch20220908_S-220729-00002_002.tiff_2713,IMMUcan_Batch20220908_S-220729-00002_002.tiff,IMMUcan_Batch20220908_S-220729-00002_002,2713.0,596.548387,596.709677,31.0,6.857501,5.700162,0.555928,600.0,600.0,2.0,S-220729-00002,190370_SPECT,Slide_IMC-TIS-01,1,10072133-SPECT-VAR-TIS-01-PB,BREAS,Batch20220908,Batch20220908_04,2,2,1,0,0,0,0,0,Mural,0,0,0,0,0,0,85.376518,,Batch20220908,Mural
IMMUcan_Batch20220908_S-220729-00002_002.tiff_2715,IMMUcan_Batch20220908_S-220729-00002_002.tiff,IMMUcan_Batch20220908_S-220729-00002_002,2715.0,180.300000,597.400000,20.0,6.484816,3.840203,0.805803,600.0,600.0,2.0,S-220729-00002,190370_SPECT,Slide_IMC-TIS-01,1,10072133-SPECT-VAR-TIS-01-PB,BREAS,Batch20220908,Batch20220908_04,2,2,1,0,0,0,0,0,Mural,0,0,0,0,0,1,10.318477,,Batch20220908,Mural
IMMUcan_Batch20220908_S-220729-00002_002.tiff_2721,IMMUcan_Batch20220908_S-220729-00002_002.tiff,IMMUcan_Batch20220908_S-220729-00002_002,2721.0,48.370370,598.111111,27.0,10.732613,3.134663,0.956397,600.0,600.0,2.0,S-220729-00002,190370_SPECT,Slide_IMC-TIS-01,1,10072133-SPECT-VAR-TIS-01-PB,BREAS,Batch20220908,Batch20220908_04,2,2,1,0,0,0,0,0,CD8,0,0,0,0,0,1,14.074760,,Batch20220908,CD8
IMMUcan_Batch20220908_S-220729-00002_002.tiff_2722,IMMUcan_Batch20220908_S-220729-00002_002.tiff,IMMUcan_Batch20220908_S-220729-00002_002,2722.0,207.969697,598.060606,33.0,12.864691,3.228974,0.967988,600.0,600.0,2.0,S-220729-00002,190370_SPECT,Slide_IMC-TIS-01,1,10072133-SPECT-VAR-TIS-01-PB,BREAS,Batch20220908,Batch20220908_04,2,2,1,0,0,0,0,0,Mural,0,0,0,0,0,0,26.662288,,Batch20220908,Mural


In [91]:
# Despite its name, `categories` is the whole "sample_id" column with unused
# category levels dropped; the levels themselves are `categories.cat.categories`.
categories = train_anndata.obs["sample_id"].cat.remove_unused_categories()

In [113]:
train_anndata.obs["cell_labels"].cat.categories[[7]]

Index(['Mural'], dtype='object')

In [109]:
train_anndata.obs["cell_labels"].cat.categories

Index(['B', 'BnT', 'CD4', 'CD8', 'DC', 'HLADR', 'MacCD163', 'Mural', 'NK',
       'Neutrophil', 'Treg', 'Tumor', 'pDC', 'plasma'],
      dtype='object')

In [117]:
(train_anndata.obs["cell_labels"].cat.codes == 5).sum()

3699

In [116]:
train_anndata[train_anndata.obs["cell_labels"].cat.codes == 5]

View of AnnData object with n_obs × n_vars = 3699 × 40
    obs: 'image', 'sample_id', 'ObjectNumber', 'Pos_X', 'Pos_Y', 'area', 'major_axis_length', 'minor_axis_length', 'eccentricity', 'width_px', 'height_px', 'acquisition_id', 'SlideId', 'Study', 'Box.Description', 'Position', 'SampleId', 'Indication', 'BatchId', 'SubBatchId', 'ROI', 'ROIonSlide', 'includeImage', 'flag_no_cells', 'flag_no_ROI', 'flag_total_area', 'flag_percent_covered', 'small_cell', 'celltypes', 'flag_tumor', 'PD1_pos', 'Ki67_pos', 'cleavedPARP_pos', 'GrzB_pos', 'tumor_patches', 'distToCells', 'CD20_patches', 'Batch', 'cell_labels'
    var: 'channel', 'use_channel', 'marker'
    layers: 'counts', 'exprs'

In [92]:
# Preview a seeded 5-fold split taken over the sample-id levels rather than
# over individual observations. The printed arrays are positions into
# `categories.cat.categories`, not the sample ids themselves.
kfold = KFold(n_splits=5, shuffle=True, random_state=124)
split = kfold.split(categories.cat.categories)
for train, test in split:
    print(train, test)

[  0   2   3   4   5   6   7   8   9  11  12  14  15  16  17  18  19  20
  21  22  25  27  28  30  31  33  35  36  37  38  39  40  41  42  43  45
  46  47  48  49  50  52  53  54  55  57  58  59  60  62  63  64  65  66
  67  69  70  71  72  73  74  75  77  78  79  81  85  86  88  91  92  93
  94  95  96  97  98  99 100 101 102 103 104 105 106 107 108 109 110 111
 113 114 116 118 119 120 121 122 123 124] [  1  10  13  23  24  26  29  32  34  44  51  56  61  68  76  80  82  83
  84  87  89  90 112 115 117]
[  1   2   4   7   8   9  10  12  13  14  16  17  18  20  21  23  24  25
  26  28  29  30  31  32  33  34  36  37  38  39  40  41  42  43  44  46
  48  49  51  52  53  54  55  56  57  58  60  61  63  64  65  66  67  68
  69  70  71  72  73  74  75  76  77  78  79  80  81  82  83  84  85  86
  87  88  89  90  91  93  94  95 101 102 103 104 105 107 108 109 111 112
 113 114 115 116 117 118 119 120 123 124] [  0   3   5   6  11  15  19  22  27  35  45  47  50  59  62  92  96  97
  98  99 1

In [44]:
train_anndata.obs["sample_id"].isin(categories.cat.categories[[5]]).sum()

563

In [14]:
X_train.shape

(214843, 40)

In [15]:
Y_train.shape

(214843,)

In [19]:
Y_test

array([ 6,  5,  6, ..., 11, 11, 11])

In [20]:
# Fit min-max scaling on the training features and apply it to them.
# NOTE(review): in the visible cells only the XGBoost grid search consumes
# X_train_scaled; the MLPClassifier is fit on the unscaled X_train and
# X_test is never passed through this scaler.
scaler = MinMaxScaler()
scaler.fit(X_train)
X_train_scaled = scaler.transform(X_train)

In [34]:
# scikit-learn MLP with two hidden layers of 160 units (40*4) each, trained
# with Adam and the library's built-in early stopping.
# NOTE(review): no `activation` is passed, so the library default applies —
# presumably ReLU, which is what the PyTorch MLP defined later uses; confirm.
clf = MLPClassifier(solver='adam', alpha=1e-5, hidden_layer_sizes=(40*4, 40*4), random_state=1, early_stopping=True, verbose=True)

In [35]:
clf.fit(X_train, Y_train)

Iteration 1, loss = 0.30038631
Validation score: 0.928159
Iteration 2, loss = 0.18615314
Validation score: 0.943401
Iteration 3, loss = 0.16284684
Validation score: 0.945402
Iteration 4, loss = 0.14703711
Validation score: 0.945812
Iteration 5, loss = 0.13729612
Validation score: 0.950771
Iteration 6, loss = 0.12874505
Validation score: 0.952364
Iteration 7, loss = 0.12157224
Validation score: 0.953547
Iteration 8, loss = 0.11701740
Validation score: 0.954639
Iteration 9, loss = 0.11320872
Validation score: 0.958551
Iteration 10, loss = 0.10887476
Validation score: 0.954366
Iteration 11, loss = 0.10530612
Validation score: 0.958096
Iteration 12, loss = 0.10347686
Validation score: 0.955912
Iteration 13, loss = 0.09943430
Validation score: 0.953137
Iteration 14, loss = 0.09661842
Validation score: 0.954229
Iteration 15, loss = 0.09481377
Validation score: 0.957869
Iteration 16, loss = 0.09259072
Validation score: 0.958824
Iteration 17, loss = 0.08925547
Validation score: 0.957368
Iterat

In [100]:
predicted = clf.predict(X_test)

In [101]:
(Y_test == predicted).sum() / len(predicted)

0.9737891049643677

In [103]:
clf.coefs_

[array([[ 0.17147431, -0.16371339, -1.04906294, ...,  0.32801828,
          0.30106442,  0.09993387],
        [-0.17967195,  0.03298444, -0.02444135, ..., -0.11306175,
         -0.03773455,  0.18553398],
        [ 0.06444886,  0.10068168, -0.36658522, ...,  0.41930423,
         -0.553215  , -0.05980344],
        ...,
        [-0.08440868,  0.04860083, -0.0693246 , ..., -0.29383697,
         -0.25772943, -0.1009417 ],
        [ 0.15811916, -0.02755875, -0.03094263, ..., -0.00918234,
         -0.114665  ,  0.11175534],
        [-0.09782374,  0.07394829,  0.14752271, ...,  0.01209931,
          0.17827707,  0.02527083]]),
 array([[-5.88809464e-02,  9.06510206e-02,  9.57129902e-03, ...,
         -1.05299900e-01, -4.68825141e-02, -1.02326726e-51],
        [ 9.49199145e-02, -2.20206470e-01, -5.66269247e-02, ...,
         -1.61293253e-01,  2.17350610e-01,  4.51008696e-51],
        [-2.39378115e-01, -2.04082290e-01, -1.09272528e-01, ...,
          1.28184324e-02,  5.02691187e-02, -3.15897865e-

In [112]:
clf.intercepts_[2]

array([ 0.07641568, -0.08381112,  0.11533853, -0.02865035, -0.17963673,
        0.10243138,  0.28469385,  0.25019946, -0.42172778, -0.66783937,
       -0.20516517, -0.0512676 , -0.64340007, -0.07672781])

In [30]:
# Grid for the XGBoost baseline: only the number of trees varies; depth,
# learning rate and objective are each fixed to a single value.
best_xgb_param_grid = {
    "n_estimators": [30, 40, 50],
    "max_depth": [3],
    "learning_rate": [0.3],
    'objective': ['multi:softmax']
}

# NOTE(review): `XGBClassifier` is not imported in any visible cell, so this
# line raises NameError on a fresh kernel — presumably
# `from xgboost import XGBClassifier` is missing from the import cell; confirm.
# NOTE(review): cv=5 folds are taken over the rows passed to fit(), not over
# whole samples as in prepare_data — verify that this is intended.
xgb_grid_search = GridSearchCV(XGBClassifier(), param_grid=best_xgb_param_grid, cv=5, n_jobs=-1, return_train_score=True, verbose=3)

In [31]:
xgb_grid_search.fit(X_train_scaled, Y_train)

Fitting 5 folds for each of 3 candidates, totalling 15 fits


In [None]:
pd.DataFrame(xgb_grid_search.cv_results_)

Unnamed: 0,mean_fit_time,std_fit_time,mean_score_time,std_score_time,param_learning_rate,param_max_depth,param_n_estimators,param_objective,params,split0_test_score,split1_test_score,split2_test_score,split3_test_score,split4_test_score,mean_test_score,std_test_score,rank_test_score,split0_train_score,split1_train_score,split2_train_score,split3_train_score,split4_train_score,mean_train_score,std_train_score
0,51.62328,0.936762,0.629865,0.024305,0.3,3,30,multi:softmax,"{'learning_rate': 0.3, 'max_depth': 3, 'n_esti...",0.913702,0.948879,0.951434,0.947971,0.799738,0.912345,0.05799,2,0.958793,0.953656,0.953509,0.954306,0.962652,0.956583,0.003603
1,66.9818,3.005664,0.813566,0.056834,0.3,3,40,multi:softmax,"{'learning_rate': 0.3, 'max_depth': 3, 'n_esti...",0.912414,0.950737,0.950399,0.949174,0.800625,0.91267,0.057895,1,0.963237,0.957975,0.958286,0.95912,0.9666,0.961044,0.003359
2,51.672913,1.078144,0.789378,0.11042,0.1,3,30,multi:softmax,"{'learning_rate': 0.1, 'max_depth': 3, 'n_esti...",0.924069,0.939187,0.945585,0.941974,0.788399,0.907843,0.06017,4,0.944154,0.939113,0.93849,0.9402,0.953477,0.943087,0.005556
3,46.893433,7.715698,0.527488,0.104851,0.1,3,40,multi:softmax,"{'learning_rate': 0.1, 'max_depth': 3, 'n_esti...",0.922211,0.941467,0.948794,0.943241,0.796296,0.910402,0.057753,3,0.947886,0.943294,0.942254,0.943832,0.956043,0.946662,0.005065


In [89]:
class MLP(nn.Module):
    """Fully connected classifier: two ReLU hidden layers and a linear output layer.

    The network returns raw class scores (logits). No softmax layer is
    applied because ``nn.CrossEntropyLoss``, the criterion used for training
    in this notebook, works on unnormalised scores.

    Parameters
    ----------
    in_features : int, default 40
        Number of input features per observation.
    hidden_features : int, default 160
        Width of both hidden layers (``40*4`` in the original definition).
    n_classes : int, default 14
        Number of output scores.

    The layers live in ``self.layers`` (an ``nn.Sequential``) with the linear
    layers at positions 0, 2 and 4, so the state-dict keys are
    ``layers.0.*``, ``layers.2.*`` and ``layers.4.*``.
    """

    def __init__(self, in_features=40, hidden_features=40 * 4, n_classes=14):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, n_classes),
        )

    def forward(self, input):
        """Return the logits for a batch of feature rows."""
        return self.layers(input)

In [90]:
# Training configuration for the PyTorch MLP below.

# Number of passes over the training DataLoader.
max_epochs=42
# Batches are moved to this device inside the training loop.
device = torch.device("cpu")

In [91]:
class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset that pairs each feature row with its label."""

    def __init__(self, X_train, Y_train):
        # Both containers are stored as given; nothing is copied or converted.
        self.X_train, self.Y_train = X_train, Y_train

    def __len__(self):
        """Number of samples, taken from the feature container."""
        return len(self.X_train)

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair stored at position ``idx``."""
        features = self.X_train[idx]
        label = self.Y_train[idx]
        return features, label


# Hold out 10% of the training rows as a validation set (fixed seed) and
# convert features to float32 and labels to int64 tensors at the same time.
# `X_val` / `y_val` are read by the training loop further down.
x_train, X_val, y_train, y_val = train_test_split(torch.tensor(X_train).float(), torch.tensor(Y_train).long(), test_size=0.1, random_state=42)

# Create CustomDataset instance
train_dataset = CustomDataset(x_train, y_train)
# val_dataset = CustomDataset(X_val, y_val)

# Create DataLoader
# Mini-batches of 200 rows, reshuffled by the loader on every pass.
batch_size = 200
train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# The validation set is evaluated in one piece by compute_accuracy, so no
# loader is built for it.
# val_data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [92]:
def compute_accuracy(model, X_val, y_val):
    """Return the fraction of rows of ``X_val`` whose arg-max prediction equals ``y_val``.

    The model is evaluated in eval mode and without gradient tracking. The
    mode the model was in before the call (training or eval) is restored
    afterwards, so calling this in the middle of training does not leave the
    model in eval mode.

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(X_val)``; must return one score per class along
        dimension 1.
    X_val : torch.Tensor
        Batch of inputs, evaluated in a single forward pass.
    y_val : torch.Tensor
        Integer class index for every row of ``X_val``.

    Returns
    -------
    torch.Tensor
        Zero-dimensional float tensor: number of correct predictions divided
        by ``len(y_val)``.
    """
    was_training = model.training
    model.eval()  # Set the model to evaluation mode
    try:
        with torch.no_grad():  # No need to compute gradients during validation
            pred = model(X_val)
            correct = (torch.argmax(pred, dim=1) == y_val).float().sum()
            accuracy = correct / len(y_val)
    finally:
        # Hand the model back in the mode the caller had it in.
        model.train(was_training)

    return accuracy

In [164]:
# Train the MLP with Adam, print a running loss, and keep a snapshot of the
# weights from the epoch with the highest validation accuracy.
model = MLP().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)
criterion = nn.CrossEntropyLoss()

start = time.time()

best_val_acc = 0.
best_model_weights = None

train_losses = []  # mean loss of every `printevery`-batch window
val_acc = []       # validation accuracy measured after every epoch

printevery = 5  # number of mini-batches per reported loss window

for epoch in range(max_epochs):
    # compute_accuracy() calls model.eval(); re-enable training mode at the
    # start of every epoch so the loop never depends on the mode that call
    # leaves behind.
    model.train()
    total_loss = 0.

    print("\r   %dm: epoch %d [%s] %d%%  loss = %s" %\
    ((time.time() - start)//60, epoch + 1, ' '*20, 0, '...'), end='')

    # `inputs` rather than `input`, so the builtin is not shadowed.
    for i, (inputs, targets) in enumerate(train_data_loader):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        pred = model(inputs)
        loss = criterion(pred, targets)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        if (i + 1) % printevery == 0:
            p = int(100 * (i + 1) / len(train_data_loader))
            avg_loss = total_loss / printevery

            print("\r   %dm: epoch %d [%s%s] %d%% train loss = %.10f" %\
            ((time.time() - start)//60, epoch + 1, '#'*(p//5), ' '*(20-(p//5)), p, avg_loss), end='')

            train_losses.append(avg_loss)
            total_loss = 0.

    val_accuracy = compute_accuracy(model, X_val.to(device), y_val.to(device))
    val_acc.append(val_accuracy)

    current_acc = float(val_accuracy)
    if best_model_weights is None or current_acc > best_val_acc:
        best_val_acc = current_acc
        # state_dict() hands out the live parameter tensors; clone them so
        # later optimizer steps cannot overwrite the snapshot.
        best_model_weights = {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}

    print("\r%dm: epoch %d [%s%s] %d%% \nepoch %d complete, val acc = %.10f" %\
    ((time.time() - start)//60, epoch + 1, '#'*(100//5), ' '*(20-(100//5)), 100, epoch + 1, val_accuracy))

# Restore the best snapshot; nothing to restore when no epoch was run.
if best_model_weights is not None:
    model.load_state_dict(best_model_weights)

0m: epoch 1 [####################] 100% % train loss = 0.2244943976
epoch 1 complete, val acc = 0.9275336266
0m: epoch 2 [####################] 100% % train loss = 0.1652928293
epoch 2 complete, val acc = 0.9398837686
0m: epoch 3 [####################] 100% % train loss = 0.1637830272
epoch 3 complete, val acc = 0.9461496472
0m: epoch 4 [####################] 100% % train loss = 0.1457840338
epoch 4 complete, val acc = 0.9495096207
1m: epoch 5 [####################] 100% % train loss = 0.1404485077
epoch 5 complete, val acc = 0.9501453042
1m: epoch 6 [####################] 100% % train loss = 0.1381207392
epoch 6 complete, val acc = 0.9407464862
1m: epoch 7 [####################] 100% % train loss = 0.1150693372
epoch 7 complete, val acc = 0.9499182701
1m: epoch 8 [####################] 100% % train loss = 0.1185332283
epoch 8 complete, val acc = 0.9496458173
2m: epoch 9 [####################] 100% % train loss = 0.1132715866
epoch 9 complete, val acc = 0.9527788162
2m: epoch 10 [#####

KeyboardInterrupt: 

In [94]:
model.load_state_dict(best_model_weights)

<All keys matched successfully>

In [165]:
compute_accuracy(model, torch.tensor(X_test).float(), torch.tensor(Y_test).long())

tensor(0.9376)

In [154]:
model_from_sklearn = MLP()

In [155]:
state_dict = model_from_sklearn.state_dict()

In [156]:
torch.tensor(clf.coefs_[0]).float().shape

torch.Size([40, 160])

In [157]:
torch.tensor(clf.intercepts_[0]).float().shape

torch.Size([160])

In [158]:
state_dict["layers.0.weight"].shape

torch.Size([160, 40])

In [159]:
state_dict["layers.0.bias"].shape

torch.Size([160])

In [160]:
# Copy the fitted scikit-learn parameters into the PyTorch state dict.
# The k-th sklearn layer maps to the Linear module at Sequential position 2*k
# (the odd positions hold the ReLU modules). sklearn stores each weight
# matrix as (inputs, outputs) while nn.Linear expects (outputs, inputs), as
# the shape checks in the preceding cells show — hence the transpose.
for layer_no, weights in enumerate(clf.coefs_):
    state_dict[f"layers.{2*layer_no}.weight"] = torch.tensor(weights).float().t()
for layer_no, biases in enumerate(clf.intercepts_):
    state_dict[f"layers.{2*layer_no}.bias"] = torch.tensor(biases).float()

In [161]:
model_from_sklearn.load_state_dict(state_dict)

<All keys matched successfully>

In [162]:
compute_accuracy(model_from_sklearn, torch.tensor(X_test).float(), torch.tensor(Y_test).long())

tensor(0.9738)