# IMPORTS

In [4]:
import torch
import numpy as np
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt
from torchvision import datasets
from sklearn.metrics import classification_report
from sklearn.metrics import ConfusionMatrixDisplay

import sys
sys.path.append("../")
from data import get_data_transform
from ensembling.ensembling import Averaging, Voting, Stacking
from model import CNN, ResNet, Inception, VGG, ViT, load_from_exp

from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression

In [7]:
resnet = load_from_exp(Path("../../EXPERIMENTS/resnet_2022-11-27-00-56"))
vgg = load_from_exp(Path("../../EXPERIMENTS/vgg_2022-11-27-00-55"))
inception = load_from_exp(Path("../../EXPERIMENTS/inception_2022-11-27-00-54"))
bag_of_models = [resnet, vgg, inception]



# DATA

In [9]:
data_transforms = get_data_transform(image_size=299, data_augmentation=0)

train_loader = torch.utils.data.DataLoader(
    datasets.ImageFolder("../../data/train_images", transform=data_transforms),
    batch_size=64,
    shuffle=False,
)
val_loader = torch.utils.data.DataLoader(
    datasets.ImageFolder("../../data/val_images", transform=data_transforms),
    batch_size=64,
    shuffle=False,
)

## EXTRACT TRAIN AND TEST SET

### TRAIN

In [10]:
targets = []
for i, (X_train, y_train) in enumerate(train_loader):
    print("Batch", i)
    for j, model in enumerate(bag_of_models):
        if j == 0:
            X_bag_i = model(X_train).detach().numpy()
        else:
            X_bag_i = np.hstack([X_bag_i, model(X_train).detach().numpy()])
    if i == 0:
        X_bag = X_bag_i
    else:
        X_bag = np.vstack([X_bag, X_bag_i])
    targets += y_train.detach().numpy().tolist()

pd.Series(targets).to_csv("y_train.csv")
pd.DataFrame(X_bag).to_csv("X_bag_train.csv")

Batch 0
Batch 1
Batch 2
Batch 3
Batch 4
Batch 5
Batch 6
Batch 7
Batch 8
Batch 9
Batch 10
Batch 11
Batch 12
Batch 13
Batch 14
Batch 15
Batch 16


### TEST

In [11]:
targets = []
for i, (X_val, y_val) in enumerate(val_loader):
    print("Batch", i)
    for j, model in enumerate(bag_of_models):
        if j == 0:
            X_bag_i = model(X_val).detach().numpy()
        else:
            X_bag_i = np.hstack([X_bag_i, model(X_val).detach().numpy()])
    if i == 0:
        X_bag = X_bag_i
    else:
        X_bag = np.vstack([X_bag, X_bag_i])
    targets += y_val.detach().numpy().tolist()

pd.Series(targets).to_csv("y_val.csv")
pd.DataFrame(X_bag).to_csv("X_bag_val.csv")

Batch 0
Batch 1


In [18]:
X_train = pd.read_csv("./X_bag_train.csv")
y_train = pd.read_csv("./y_train.csv", index_col=0)
X_val = pd.read_csv("./X_bag_val.csv")
y_val = pd.read_csv("./y_val.csv", index_col=0)

In [26]:
model = RandomForestClassifier(n_estimators=100)
model = LogisticRegression()
model.fit(X_train, y_train)

  y = column_or_1d(y, warn=True)
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


In [27]:
y_pred = model.predict(X_val)

In [63]:
import torch

In [65]:
avg = np.zeros((3, len(X_val), 20))
for i, (start, end) in enumerate([[0, 20], [20, 40], [40, 60]]):
    avg[i] = torch.nn.Softmax()(torch.tensor(X_val.iloc[:, start:end].values)).numpy()

  avg[i] = torch.nn.Softmax()(torch.tensor(X_val.iloc[:, start:end].values)).numpy()


In [68]:
avg

array([[[4.47525627e-02, 7.92900288e-02, 6.27186888e-02, ...,
         9.19432566e-02, 6.17457206e-02, 4.26319836e-02],
        [9.72548457e-02, 1.05073091e-01, 3.96296826e-02, ...,
         4.73358567e-02, 7.45662959e-02, 3.87891651e-02],
        [2.19736428e-01, 3.69361628e-02, 8.42389389e-02, ...,
         7.95268604e-02, 3.86921908e-02, 1.65168690e-02],
        ...,
        [1.00000000e+00, 3.53669509e-44, 3.92462620e-44, ...,
         2.29108117e-44, 6.48285949e-44, 5.10785613e-44],
        [1.00000000e+00, 7.91252055e-45, 8.60550767e-45, ...,
         6.12776072e-45, 1.64418002e-44, 1.68385969e-44],
        [1.00000000e+00, 3.11883391e-45, 5.50758584e-45, ...,
         2.96728052e-45, 6.49201860e-45, 6.10038751e-45]],

       [[1.51209257e-02, 1.25671994e-01, 1.28954869e-01, ...,
         9.77157596e-02, 2.25894208e-02, 1.50090632e-02],
        [7.88110332e-03, 2.58391896e-01, 1.03302255e-01, ...,
         5.32158992e-02, 3.76284905e-02, 2.65562873e-02],
        [1.12771904e-02, 

In [66]:
print(
    classification_report(
        y_true=y_val,
        y_pred=avg.mean(axis=0).argmax(axis=1)
    )
)

              precision    recall  f1-score   support

           0       0.00      0.00      0.00       2.0
           1       0.00      0.00      0.00       4.0
           2       0.00      0.00      0.00       7.0
           3       0.00      0.00      0.00       3.0
           4       0.00      0.00      0.00       2.0
           5       0.00      0.00      0.00       2.0
           6       0.00      0.00      0.00       3.0
           7       0.00      0.00      0.00       4.0
           8       0.00      0.00      0.00       6.0
           9       0.00      0.00      0.00       5.0
          10       0.00      0.00      0.00       8.0
          11       0.00      0.00      0.00       6.0
          12       0.00      0.00      0.00       6.0
          13       0.00      0.00      0.00       8.0
          14       0.00      0.00      0.00       7.0
          15       0.00      0.00      0.00       7.0
          16       0.00      0.00      0.00       6.0
          17       0.00    

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [30]:
print(
    classification_report(
        y_true=y_val,
        y_pred=y_pred
    )
)

              precision    recall  f1-score   support

           0       0.33      1.00      0.50         2
           1       0.15      0.50      0.24         4
           2       0.56      0.71      0.63         7
           3       0.25      0.33      0.29         3
           4       0.67      1.00      0.80         2
           5       1.00      1.00      1.00         2
           6       0.21      1.00      0.35         3
           7       0.40      0.50      0.44         4
           8       1.00      0.17      0.29         6
           9       0.50      0.20      0.29         5
          10       0.89      1.00      0.94         8
          11       1.00      0.50      0.67         6
          12       0.80      0.67      0.73         6
          13       1.00      0.25      0.40         8
          14       0.86      0.86      0.86         7
          15       0.50      0.43      0.46         7
          16       0.00      0.00      0.00         6
          17       1.00    

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
