In [1]:
import torchvision
import torch.nn as nn
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms,models,datasets
import matplotlib.pyplot as plt
from collections.abc import Iterable
import numpy as np
from torch import optim
from tqdm import tqdm
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
FOLD = 1
MAG = 40
import cv2, numpy as np, pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
%matplotlib inline
from sklearn.exceptions import UndefinedMetricWarning
import warnings
warnings.filterwarnings(action='ignore', category=UndefinedMetricWarning)
import os



In [11]:
def f1_macro(y_true, y_pred):
    return f1_score(y_true, y_pred, average='macro')

In [2]:
import pandas as pd

df = pd.read_csv("Folds.csv")

In [3]:
from torch.utils.data import DataLoader, Dataset
class imagesXNN_train(Dataset):
  def __init__(self, paths):
    self.fpaths = paths

  def __len__(self): return len(self.fpaths)

  def __getitem__(self, ix):
    f = self.fpaths[ix]
    res = torch.tensor(int("malignant" in f)).type(torch.LongTensor)
    im = cv2.imread(f)[:,:,::-1]
    return transforms.Compose([transforms.Resize((300, 300)), 
                               transforms.RandomHorizontalFlip(), 
                               transforms.RandomVerticalFlip()])(torch.tensor(im/255).permute(2,0,1).float()), res

In [4]:
from torch.utils.data import DataLoader, Dataset
class imagesXNN_test(Dataset):
  def __init__(self, paths):
    self.fpaths = paths

  def __len__(self): return len(self.fpaths)

  def __getitem__(self, ix):
    f = self.fpaths[ix]
    res = torch.tensor(int("malignant" in f)).type(torch.LongTensor)
    im = cv2.imread(f)[:,:,::-1]
    return transforms.Compose([transforms.Resize((300, 300))])(torch.tensor(im/255).permute(2,0,1).float()), res

In [5]:
def train_epoch(model, loader, criterion, optim, local_metric_fn: Iterable):
    local_metric = [0] * len(local_metric_fn)
    local_loss = 0

    model.train()
    for step, (images, labels) in tqdm(enumerate(loader), total=312):
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)
        outputs = model(images)
        local_loss = criterion(outputs, labels)
        local_loss.backward()
        optim.step()
        optim.zero_grad()
        local_loss += local_loss.item()
        for i, metric in enumerate(local_metric_fn):
            local_metric[i] += metric(outputs.cpu().detach().numpy().argmax(axis=1), labels.cpu().numpy())
    return local_loss/(step + 1), [metric/(step + 1) for metric in local_metric]

In [6]:
def valid_epoch(model, loader, criterion, local_metric_fn: Iterable):
    local_metric = [0] * len(local_metric_fn)
    local_loss = 0

    model.eval()
    for step, (images, labels) in tqdm(enumerate(loader), total=187):
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)
        with torch.no_grad():
            outputs = model(images)
            local_loss = criterion(outputs, labels)
            local_loss += local_loss.item()
            for i, metric in enumerate(local_metric_fn):
                local_metric[i] += metric(outputs.cpu().detach().numpy().argmax(axis=1), labels.cpu().numpy())
    return local_loss/(step + 1), [metric/(step + 1) for metric in local_metric]

In [7]:

def train(model, num_epochs, train_dl, valid_dl, loss_fn, optimizer, local_metric_fn, i=0, tolerance = 0.5):
    loss_hist_train = [0] * num_epochs
    accuracy_hist_train = [0] * num_epochs
    loss_hist_valid = [0] * num_epochs
    accuracy_hist_valid = [0] * num_epochs

    for epoch in range(num_epochs):

        loss_tr_ep, metrics_tr_ep = train_epoch(model, train_dl, loss_fn, optimizer, local_metric_fn)

        loss_hist_train[epoch] = loss_tr_ep
        accuracy_hist_train[epoch] = metrics_tr_ep
        
        loss_ep, metrics_ep = valid_epoch(model, valid_dl, loss_fn, local_metric_fn)
        
        loss_hist_valid[epoch] = loss_ep
        accuracy_hist_valid[epoch] = metrics_ep
        if epoch == 0:
            bestAccVal = accuracy_hist_valid[epoch][2]
            torch.save(model.state_dict(), f'6models_fulldata_{i}_{FOLD}.pt')
        else:
            if accuracy_hist_valid[epoch][2] > max(bestAccVal, tolerance):
                bestAccVal = accuracy_hist_valid[epoch][2]
                torch.save(model.state_dict(), f'6models_fulldata_{i}_{FOLD}.pt')

        print(f'Epoch {epoch+1} f1: {accuracy_hist_train[epoch][3]:{1}.{5}} val_f1: {accuracy_hist_valid[epoch][3]:{1}.{5}}')
    if os.path.exists(f'6models_fulldata_{i}_{FOLD}.pt'):
        model.load_state_dict(torch.load(f'6models_fulldata_{i}_{FOLD}.pt'))
    return loss_hist_train, loss_hist_valid, accuracy_hist_train, accuracy_hist_valid, bestAccVal

In [8]:
fold1 = df.loc[(df.fold==FOLD) & (df.mag==MAG) & (df.grp=="train")]["filename"].sample(frac=0.78).values
fold2 = df.loc[(df.fold==FOLD) & (df.mag==MAG) & (df.grp=="train")]["filename"].sample(frac=0.78).values
train_full = df.loc[(df.fold==FOLD) & (df.mag==MAG) & (df.grp=="train")]["filename"].values

In [9]:
trn_ds_s1 = imagesXNN_train(fold1)
trn_dl_s1 = DataLoader(trn_ds_s1, batch_size=2, shuffle=True, drop_last = True, num_workers = 0, pin_memory=True)

trn_ds_s2 = imagesXNN_train(fold2)
trn_dl_s2 = DataLoader(trn_ds_s2, batch_size=2, shuffle=True, drop_last = True, num_workers = 0, pin_memory=True)

test_40 = imagesXNN_test(df.loc[(df.fold==FOLD) & (df.mag==MAG) & (df.grp=="test")]["filename"].values)
test_dl = DataLoader(test_40, batch_size=4, shuffle=False, drop_last = True, num_workers = 0)

trn_ds = imagesXNN_train(train_full)
trn_dl = DataLoader(trn_ds, batch_size=4, shuffle=True, drop_last = True, num_workers = 0)

### Train single classifier

In [None]:
model = torchvision.models.efficientnet_v2_l(weights=models.EfficientNet_V2_L_Weights.DEFAULT)
modelname="efficientnet_v2_l"
for layer in model.parameters():
    layer.requires_grad = False
model.classifier = nn.Sequential(nn.Linear(1280, 960), nn.ReLU(), nn.Dropout(0.2), nn.Linear(960, 540), nn.ReLU(), nn.Linear(540, 320), nn.ReLU(), nn.Linear(320, 100), nn.ReLU(), nn.Linear(100, 2), nn.Softmax(dim=1))
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr= 1e-5)
model = model.to(DEVICE)
hst = train(model, 15, trn_dl, test_dl, loss_fn, optimizer,  [recall_score, precision_score, accuracy_score, f1_macro], i=11)

In [None]:
model = torchvision.models.efficientnet_b7(weights=models.EfficientNet_B7_Weights.DEFAULT)
modelname="efficientnet_v1_b7"
model.classifier = nn.Sequential(nn.Linear(2560, 1560), nn.ReLU(), nn.Dropout(0.2), nn.Linear(1560, 960), nn.ReLU(), nn.Linear(960, 540), nn.ReLU(), nn.Linear(540, 320), nn.ReLU(), nn.Linear(320, 2), nn.Softmax(dim=1))
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr= 1e-5)
model = model.to(DEVICE)
hst_v1b7 = train(model, 15, trn_dl, test_dl, loss_fn, optimizer,  [recall_score, precision_score, accuracy_score, f1_macro], i=10)

### Initialize ensemble models

In [10]:
model_constructor = [torchvision.models.efficientnet_v2_l(weights=torchvision.models.EfficientNet_V2_L_Weights.DEFAULT),
                     torchvision.models.efficientnet_v2_m(weights=torchvision.models.EfficientNet_V2_M_Weights.DEFAULT),
                     torchvision.models.efficientnet_v2_s(weights=torchvision.models.EfficientNet_V2_S_Weights.DEFAULT),
                     torchvision.models.efficientnet_v2_l(),
                     torchvision.models.efficientnet_v2_m(),
                     torchvision.models.efficientnet_v2_s()
                     ]



def generate_model(i):
    model = model_constructor[i]
    model.classifier = nn.Sequential(nn.Linear(1280, 960), nn.Dropout(0.2, inplace=True), nn.ReLU(), nn.Linear(960, 480), nn.ReLU(), nn.Linear(480, 2))
    loss = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr= 1e-5)
    model.to(DEVICE)

    return {
        'model': model,
        'centropy': loss,
        'optimizer': optimizer,
        }


In [12]:
models = [generate_model(j) for j in range(6)]

In [43]:
class MyEnsemble(nn.Module):
    def __init__(self, models):
        super(MyEnsemble, self).__init__()
        self.models = models

    def forward(self, x):
        preds = [models["model"](x).detach().cpu().numpy() for models in self.models]
        
        preds_w = np.array(preds)
        preds_w = nn.Softmax(dim=2)(torch.Tensor(preds_w)).numpy()
        maxs = preds_w.max(axis=2)
        maxs = nn.Softmax(dim=0)(torch.Tensor(maxs)).numpy()
        amaxs = preds_w.argmax(axis=2) * 2 - 1
        pred_w = (np.sum(maxs * amaxs, axis=0) > 0) * 1
        

        pred = np.array(preds)
        y = pred.argmax(axis=2).sum(axis=0)/float(len(self.models))
        return (y > 0.5)*1, pred_w

### Load Checkpoints

In [None]:
for i, model in enumerate(models):
    model["model"].load_state_dict(torch.load(f'6models_fulldata_{i+1}_{FOLD}.pt'))

### Train Ensemble

In [13]:
epochs = [5,5,5,10,10,10]
i = 0
hist = []
local_metric_fn = [recall_score, precision_score, accuracy_score, f1_macro]
local_metric = [0] * len(local_metric_fn)
for mdict, epoch in zip(models, epochs):
    i += 1
    hist.append(train(model = mdict["model"], 
                        num_epochs= epoch, 
                        train_dl = trn_dl, 
                        valid_dl=test_dl, 
                        loss_fn=mdict["centropy"], 
                        optimizer=mdict["optimizer"], 
                        local_metric_fn=local_metric_fn,
                        i=i)) 
scores = [i[4] for i in hist]


100%|██████████| 312/312 [01:25<00:00,  3.63it/s]
 99%|█████████▉| 186/187 [00:19<00:00,  9.54it/s]


Epoch 1 f1: 0.64005 val_f1: 0.77819


100%|██████████| 312/312 [01:25<00:00,  3.64it/s]
 99%|█████████▉| 186/187 [00:19<00:00,  9.53it/s]


Epoch 2 f1: 0.82407 val_f1: 0.75981


100%|██████████| 312/312 [01:30<00:00,  3.46it/s]
 99%|█████████▉| 186/187 [00:19<00:00,  9.73it/s]


Epoch 3 f1: 0.85466 val_f1: 0.80963


100%|██████████| 312/312 [01:27<00:00,  3.58it/s]
 99%|█████████▉| 186/187 [00:19<00:00,  9.49it/s]


Epoch 4 f1: 0.90527 val_f1: 0.81859


100%|██████████| 312/312 [01:27<00:00,  3.58it/s]
 99%|█████████▉| 186/187 [00:19<00:00,  9.75it/s]


Epoch 5 f1: 0.91752 val_f1: 0.83641


100%|██████████| 312/312 [00:57<00:00,  5.44it/s]
 99%|█████████▉| 186/187 [00:15<00:00, 12.09it/s]


Epoch 1 f1: 0.54231 val_f1: 0.67143


100%|██████████| 312/312 [00:59<00:00,  5.24it/s]
 99%|█████████▉| 186/187 [00:15<00:00, 11.82it/s]


Epoch 2 f1: 0.73094 val_f1: 0.77696


100%|██████████| 312/312 [00:59<00:00,  5.25it/s]
 99%|█████████▉| 186/187 [00:15<00:00, 11.82it/s]


Epoch 3 f1: 0.81564 val_f1: 0.82156


100%|██████████| 312/312 [01:01<00:00,  5.07it/s]
 99%|█████████▉| 186/187 [00:15<00:00, 11.88it/s]


Epoch 4 f1: 0.82769 val_f1: 0.82125


100%|██████████| 312/312 [00:57<00:00,  5.41it/s]
 99%|█████████▉| 186/187 [00:15<00:00, 11.87it/s]


Epoch 5 f1: 0.86969 val_f1: 0.77437


100%|██████████| 312/312 [00:46<00:00,  6.69it/s]
 99%|█████████▉| 186/187 [00:13<00:00, 14.11it/s]


Epoch 1 f1: 0.51978 val_f1: 0.66183


100%|██████████| 312/312 [00:45<00:00,  6.82it/s]
 99%|█████████▉| 186/187 [00:13<00:00, 13.86it/s]


Epoch 2 f1: 0.6409 val_f1: 0.65366


100%|██████████| 312/312 [00:45<00:00,  6.84it/s]
 99%|█████████▉| 186/187 [00:13<00:00, 13.89it/s]


Epoch 3 f1: 0.79092 val_f1: 0.66426


100%|██████████| 312/312 [00:43<00:00,  7.16it/s]
 99%|█████████▉| 186/187 [00:13<00:00, 13.76it/s]


Epoch 4 f1: 0.85919 val_f1: 0.68131


100%|██████████| 312/312 [00:43<00:00,  7.21it/s]
 99%|█████████▉| 186/187 [00:13<00:00, 13.90it/s]


Epoch 5 f1: 0.86021 val_f1: 0.76544


100%|██████████| 312/312 [01:25<00:00,  3.63it/s]
 99%|█████████▉| 186/187 [00:19<00:00,  9.59it/s]


Epoch 1 f1: 0.52022 val_f1: 0.65699


100%|██████████| 312/312 [01:25<00:00,  3.66it/s]
 99%|█████████▉| 186/187 [00:19<00:00,  9.68it/s]


Epoch 2 f1: 0.53567 val_f1: 0.67437


100%|██████████| 312/312 [01:33<00:00,  3.35it/s]
 99%|█████████▉| 186/187 [00:19<00:00,  9.76it/s]


Epoch 3 f1: 0.63443 val_f1: 0.71843


100%|██████████| 312/312 [01:25<00:00,  3.64it/s]
 99%|█████████▉| 186/187 [00:19<00:00,  9.68it/s]


Epoch 4 f1: 0.63396 val_f1: 0.69811


100%|██████████| 312/312 [01:24<00:00,  3.68it/s]
 99%|█████████▉| 186/187 [00:19<00:00,  9.51it/s]


Epoch 5 f1: 0.59478 val_f1: 0.6693


100%|██████████| 312/312 [01:25<00:00,  3.66it/s]
 99%|█████████▉| 186/187 [00:19<00:00,  9.46it/s]


Epoch 6 f1: 0.62196 val_f1: 0.63162


100%|██████████| 312/312 [01:26<00:00,  3.62it/s]
 99%|█████████▉| 186/187 [00:19<00:00,  9.64it/s]


Epoch 7 f1: 0.66378 val_f1: 0.67087


100%|██████████| 312/312 [01:26<00:00,  3.62it/s]
 99%|█████████▉| 186/187 [00:19<00:00,  9.75it/s]


Epoch 8 f1: 0.67874 val_f1: 0.68651


100%|██████████| 312/312 [01:25<00:00,  3.63it/s]
 99%|█████████▉| 186/187 [00:19<00:00,  9.43it/s]


Epoch 9 f1: 0.64655 val_f1: 0.67368


100%|██████████| 312/312 [01:25<00:00,  3.63it/s]
 99%|█████████▉| 186/187 [00:19<00:00,  9.66it/s]


Epoch 10 f1: 0.67172 val_f1: 0.74921


100%|██████████| 312/312 [00:59<00:00,  5.26it/s]
 99%|█████████▉| 186/187 [00:15<00:00, 11.65it/s]


Epoch 1 f1: 0.52592 val_f1: 0.65699


100%|██████████| 312/312 [00:58<00:00,  5.33it/s]
 99%|█████████▉| 186/187 [00:15<00:00, 11.76it/s]


Epoch 2 f1: 0.5756 val_f1: 0.66109


100%|██████████| 312/312 [00:58<00:00,  5.29it/s]
 99%|█████████▉| 186/187 [00:15<00:00, 11.76it/s]


Epoch 3 f1: 0.62056 val_f1: 0.72243


100%|██████████| 312/312 [00:58<00:00,  5.36it/s]
 99%|█████████▉| 186/187 [00:15<00:00, 11.76it/s]


Epoch 4 f1: 0.63703 val_f1: 0.66687


100%|██████████| 312/312 [00:58<00:00,  5.33it/s]
 99%|█████████▉| 186/187 [00:15<00:00, 11.84it/s]


Epoch 5 f1: 0.67271 val_f1: 0.72071


100%|██████████| 312/312 [00:58<00:00,  5.34it/s]
 99%|█████████▉| 186/187 [00:15<00:00, 11.81it/s]


Epoch 6 f1: 0.65131 val_f1: 0.69747


100%|██████████| 312/312 [00:59<00:00,  5.24it/s]
 99%|█████████▉| 186/187 [00:15<00:00, 11.67it/s]


Epoch 7 f1: 0.65455 val_f1: 0.71224


100%|██████████| 312/312 [00:58<00:00,  5.34it/s]
 99%|█████████▉| 186/187 [00:15<00:00, 11.88it/s]


Epoch 8 f1: 0.67065 val_f1: 0.66477


100%|██████████| 312/312 [00:57<00:00,  5.40it/s]
 99%|█████████▉| 186/187 [00:15<00:00, 12.13it/s]


Epoch 9 f1: 0.67625 val_f1: 0.67384


100%|██████████| 312/312 [00:57<00:00,  5.39it/s]
 99%|█████████▉| 186/187 [00:15<00:00, 11.80it/s]


Epoch 10 f1: 0.68362 val_f1: 0.72071


100%|██████████| 312/312 [00:44<00:00,  6.98it/s]
 99%|█████████▉| 186/187 [00:13<00:00, 13.98it/s]


Epoch 1 f1: 0.50943 val_f1: 0.65699


100%|██████████| 312/312 [00:44<00:00,  6.97it/s]
 99%|█████████▉| 186/187 [00:13<00:00, 13.50it/s]


Epoch 2 f1: 0.56838 val_f1: 0.72309


100%|██████████| 312/312 [00:44<00:00,  7.04it/s]
 99%|█████████▉| 186/187 [00:13<00:00, 13.37it/s]


Epoch 3 f1: 0.64666 val_f1: 0.72304


100%|██████████| 312/312 [00:44<00:00,  7.04it/s]
 99%|█████████▉| 186/187 [00:13<00:00, 13.31it/s]


Epoch 4 f1: 0.63158 val_f1: 0.71603


100%|██████████| 312/312 [00:44<00:00,  6.95it/s]
 99%|█████████▉| 186/187 [00:13<00:00, 13.37it/s]


Epoch 5 f1: 0.65737 val_f1: 0.71976


100%|██████████| 312/312 [00:43<00:00,  7.10it/s]
 99%|█████████▉| 186/187 [00:13<00:00, 13.36it/s]


Epoch 6 f1: 0.66065 val_f1: 0.69595


100%|██████████| 312/312 [00:45<00:00,  6.88it/s]
 99%|█████████▉| 186/187 [00:14<00:00, 13.27it/s]


Epoch 7 f1: 0.65856 val_f1: 0.6786


100%|██████████| 312/312 [00:45<00:00,  6.84it/s]
 99%|█████████▉| 186/187 [00:14<00:00, 12.94it/s]


Epoch 8 f1: 0.69103 val_f1: 0.65059


100%|██████████| 312/312 [00:44<00:00,  6.97it/s]
 99%|█████████▉| 186/187 [00:14<00:00, 13.21it/s]


Epoch 9 f1: 0.67262 val_f1: 0.58633


100%|██████████| 312/312 [00:43<00:00,  7.10it/s]
 99%|█████████▉| 186/187 [00:13<00:00, 13.43it/s]


Epoch 10 f1: 0.69629 val_f1: 0.66293


### Calculate ensemble output

In [46]:
ensemble = MyEnsemble(models)

local_metric_fn = [recall_score, precision_score, accuracy_score, f1_macro]
local_metric_simple = [0] * len(local_metric_fn)
local_metric_conf = [0] * len(local_metric_fn)
with torch.no_grad():
    for step, (images, labels) in enumerate(test_dl):
        images = images.to(DEVICE) 
        labels = labels.squeeze() 
        outputs = ensemble.forward(images)
        for i, metric in enumerate(local_metric_fn):
            local_metric_simple[i] += metric(outputs[0], labels)
            local_metric_conf[i] += metric(outputs[1], labels)
local_metric_simple = [metric/(step + 1) for metric in local_metric_simple]
local_metric_conf = [metric/(step + 1) for metric in local_metric_conf]
print(local_metric_simple)
print(local_metric_conf)

[0.6559139784946236, 0.5940860215053764, 0.8803763440860215, 0.8337941628264207]
[0.6612903225806451, 0.6223118279569892, 0.8897849462365591, 0.8474142345110084]


In [40]:
for model_hist in range(6):
    print(np.array(hist[i][3]).max(axis=0))

[0.66129032 0.62231183 0.89247312 0.83640553]
[0.66129032 0.62231183 0.89247312 0.83640553]
[0.66129032 0.62231183 0.89247312 0.83640553]
[0.66129032 0.62231183 0.89247312 0.83640553]
[0.66129032 0.62231183 0.89247312 0.83640553]
[0.66129032 0.62231183 0.89247312 0.83640553]


### 3-stacked ensemble model

In [None]:
def train(model, num_epochs, train_dl, valid_dl, loss_fn, optimizer, local_metric_fn, i=0, tolerance = 0.5):
    loss_hist_train = [0] * num_epochs
    accuracy_hist_train = [0] * num_epochs
    loss_hist_valid = [0] * num_epochs
    accuracy_hist_valid = [0] * num_epochs

    for epoch in range(num_epochs):

        loss_tr_ep, metrics_tr_ep = train_epoch(model, train_dl, loss_fn, optimizer, local_metric_fn)

        loss_hist_train[epoch] = loss_tr_ep
        accuracy_hist_train[epoch] = metrics_tr_ep
        
        loss_ep, metrics_ep = valid_epoch(model, valid_dl, loss_fn, local_metric_fn)
        
        loss_hist_valid[epoch] = loss_ep
        accuracy_hist_valid[epoch] = metrics_ep
        if epoch == 0:
            bestAccVal = accuracy_hist_valid[epoch][2]
            torch.save(model.state_dict(), f'ensemble_3b0{i}_{FOLD}.pt')
        else:
            if accuracy_hist_valid[epoch][2] > max(bestAccVal, tolerance):
                bestAccVal = accuracy_hist_valid[epoch][2]
                torch.save(model.state_dict(), f'ensemble_3b0{i}_{FOLD}.pt')

        print(f'Epoch {epoch+1} f1: {accuracy_hist_train[epoch][3]:{1}.{5}} val_f1: {accuracy_hist_valid[epoch][3]:{1}.{5}}')
    if os.path.exists(f'ensemble_3b0{i}_{FOLD}.pt'):
        model.load_state_dict(torch.load(f'ensemble_3b0{i}_{FOLD}.pt'))
    return loss_hist_train, loss_hist_valid, accuracy_hist_train, accuracy_hist_valid, bestAccVal

In [None]:
class ensemble(nn.Module):
    def __init__(self):
        super(ensemble, self).__init__()
        self.model1 = nn.Sequential(torchvision.models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT).features,
                                    torchvision.models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT).avgpool)
        self.model2 = nn.Sequential(torchvision.models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT).features,
                                    torchvision.models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT).avgpool)
        self.model3 = nn.Sequential(torchvision.models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT).features,
                                    torchvision.models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT).avgpool)
        self.model1 = self.model1.to(DEVICE)
        self.model2 = self.model2.to(DEVICE)
        self.model3 = self.model3.to(DEVICE)

    def forward(self, x):
        out1 = self.model1(x)
        out_test = torchvision.models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT).features.to(DEVICE)(x)
        out2 = self.model2(x)
        out3 = self.model3(x)
        cat = torch.concat((out1, out2, out3), dim = 1)
        return nn.Sequential(nn.Flatten(),
                             nn.Linear(1280*3, 960), 
                             nn.ReLU(), 
                             nn.Dropout(0.2), 
                             nn.Linear(960, 540), 
                             nn.ReLU(), 
                             nn.Linear(540, 320), 
                             nn.ReLU(), 
                             nn.Linear(320, 100), 
                             nn.ReLU(), 
                             nn.Linear(100, 2), 
                             nn.Softmax(dim=1)).to(DEVICE)(cat)
        


In [None]:
model = ensemble()
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr= 1e-5)
model = model.to(DEVICE)
hst = train(model, 15, trn_dl, test_dl, loss_fn, optimizer,  [recall_score, precision_score, accuracy_score, f1_macro], i=11)