In [1]:
import torch
import random
import copy
import itertools
import numpy as np
from sklearn.metrics import classification_report
import utils

from LIM_deep_neural_classifier import LIMDeepNeuralClassifier
import dataset_semantic
from trainer import LIMTrainer

In [2]:
# Fix random seeds via the project helper in `utils` so the runs below are repeatable.
# NOTE(review): which RNGs it seeds (random / numpy / torch?) is defined in `utils` -- verify.
utils.fix_random_seeds()

In [3]:
# File the pretrained base model's weights are saved to, and reloaded from before
# each configuration of the IIT sweep further down.
PATH = "basemodel"


# Base (pre-IIT) training inputs and multi-class label rows from the project dataset module.
X, y = dataset_semantic.generate_data()

In [4]:
# Number of base training examples.
print(len(X))

84


In [5]:
# Number of label rows, then the width of a single label row (the class count).
print(len(y), len(y[0]), sep="\n")

84
53


In [6]:
# Base classifier and its trainer. Sizes follow the constructor's parameter names:
# a (VOCAB_SIZE, embedding_dim) embedding table and `num_layers` hidden layers of
# width 2*embedding_dim.
embedding_dim = 256
VOCAB_SIZE = 25
# Derive the number of output classes from the labels rather than hard-coding it,
# so the classifier head can never silently disagree with the data.
N_CLASSES = len(y[0])
LIM = LIMDeepNeuralClassifier(
    hidden_dim=embedding_dim*2,
    hidden_activation=torch.nn.ReLU(),
    num_layers=2,
    # NOTE(review): presumably two concatenated token embeddings -- confirm.
    input_dim=embedding_dim*2,
    n_classes=N_CLASSES,
    embedding_size_and_dim=(VOCAB_SIZE, embedding_dim)
    )
LIM_trainer = LIMTrainer(
    LIM,
    warm_start=True,
    max_iter=1000,
    batch_size=16,
    l2_strength=0.0001,
    # Deliberately larger than max_iter: assuming the usual "stop after N epochs
    # without improvement" meaning, early stopping can never trigger -- verify.
    n_iter_no_change=10000,
    shuffle_train=False,
    eta=0.001)

In [7]:

# Pretrain the base classifier on (X, y); the return value is not needed.
_ = LIM_trainer.fit(X, y)

Finished epoch 1000 of 1000; error is 15.918632507324219

In [8]:
# Logits of the pretrained model on its own training inputs, moved to CPU for the report below.
preds = LIM_trainer.predict_logits(X).cpu()

In [9]:
# (examples, classes) -- 84 x 53 in the run shown.
print(preds.shape)

torch.Size([84, 53])


In [10]:
# Per-class report over every class (classes="all") on the training data.
print(dataset_semantic.classwise_report(y, preds, classes="all"))

CLASSIFICATION FOR CLASS:pine
              precision    recall  f1-score   support

         0.0       1.00      1.00      1.00        83
         1.0       1.00      1.00      1.00         1

    accuracy                           1.00        84
   macro avg       1.00      1.00      1.00        84
weighted avg       1.00      1.00      1.00        84


CLASSIFICATION FOR CLASS:oak
              precision    recall  f1-score   support

         0.0       1.00      1.00      1.00        83
         1.0       1.00      1.00      1.00         1

    accuracy                           1.00        84
   macro avg       1.00      1.00      1.00        84
weighted avg       1.00      1.00      1.00        84


CLASSIFICATION FOR CLASS:maple
              precision    recall  f1-score   support

         0.0       1.00      1.00      1.00        83
         1.0       1.00      1.00      1.00         1

    accuracy                           1.00        84
   macro avg       1.00      1.00   

In [11]:
# Persist the pretrained weights; the IIT sweep reloads them before every configuration.
torch.save(LIM_trainer.model.state_dict(), PATH)

In [12]:
# IIT dataset: base inputs/labels, source inputs, IIT labels, intervention ids,
# and the blackout class indices.
dataset = dataset_semantic.generate_data_plants_have_roots_and_animals_can_move_and_have_skin()
X_base_train, y_base_train, X_sources_train, y_IIT_train, interventions, blackout_classes = dataset
# The three IIT-specific components (positions 2..4 of `dataset`), handed to `fit` as `iit_data`.
iit_data = (X_sources_train, y_IIT_train, interventions)

In [13]:
# One (base input, base label, source input, IIT label, intervention) tuple per example.
# `X_sources_train` is indexed with [0] because it holds one batch per source -- the
# evaluation cells pass it to `iit_predict_logits` as `[X_sources_train[0]]` next to the
# full base batch. Zipping over `X_sources_train` itself would iterate over sources, not
# examples, and silently truncate the tuples. NOTE(review): confirm X_sources_train's layout.
# `list(...)` materializes the tuples so they can be iterated more than once
# (a bare `zip` is a one-shot iterator).
data = list(zip(X_base_train, y_base_train, X_sources_train[0], y_IIT_train, interventions))
assert len(data) == len(X_base_train), "IIT components are not aligned per example"

def balance(data, class_indices=None):
    """Balance IIT examples between "label changed" and "label unchanged" groups.

    TODO(review): this function was an unfinished stub (its body was a syntax error
    right after ``for index in blackout_classes:``). The completion below is an
    assumption about the intended behavior -- confirm before relying on it.

    Each element of ``data`` is expected to be a 5-tuple
    ``(base_input, base_label, source_input, iit_label, intervention)``. An example
    counts as "some change" when its IIT label differs from its base label at any
    index in ``class_indices``, and as "no change" otherwise. The larger group is then
    randomly down-sampled (``random.sample``) to the size of the smaller one.

    Args:
        data: iterable of 5-tuples; it is consumed exactly once.
        class_indices: label positions to compare. Defaults to the notebook-level
            ``blackout_classes``.

    Returns:
        list: the kept "some change" examples followed by the kept "no change"
        examples, both groups of equal size. If either group is empty, nothing can be
        balanced and every example is returned (changed ones first) instead of
        silently dropping the whole dataset.
    """
    if class_indices is None:
        class_indices = blackout_classes
    some_change = []
    no_change = []
    for d in data:
        y_base, y_iit = d[1], d[3]
        # bool(...) so element-wise tensor/array comparisons collapse to a Python bool.
        if any(bool(y_iit[index] != y_base[index]) for index in class_indices):
            some_change.append(d)
        else:
            no_change.append(d)
    n_keep = min(len(some_change), len(no_change))
    if n_keep == 0:
        return some_change + no_change
    return random.sample(some_change, n_keep) + random.sample(no_change, n_keep)

In [14]:
# Reconfigure the trainer for IIT fine-tuning: enable the model's analysis mode,
# train for only 15 epochs, drop L2 regularization, and lower eta (0.001 -> 0.0001).
LIM_trainer.model.set_analysis_mode(True)
LIM_trainer.max_iter = 15
# NOTE(review): the attribute is spelled `blackoutclasses` (no underscore) while the
# variable is `blackout_classes`; confirm this is the name LIMTrainer actually reads,
# otherwise this assignment just creates an unused attribute.
LIM_trainer.blackoutclasses = blackout_classes
LIM_trainer.l2_strength = 0.0
LIM_trainer.eta = 0.0001

In [15]:

# IIT fine-tune with intervention id 0 on layer 1, presumably covering units
# [0, embedding_dim // 4) of that layer (start/end keys) -- confirm against LIMTrainer.
id_to_coords = {0: [dict(layer=1, start=0, end=embedding_dim // 4)]}
_ = LIM_trainer.fit(X_base_train, y_base_train,
                    iit_data=iit_data, intervention_ids_to_coords=id_to_coords)

Finished epoch 15 of 15; error is 2631.4665012359623

In [16]:
# Check how well the IIT-fine-tuned model still predicts the base (non-intervened) labels.
base_preds = LIM_trainer.predict_logits(X_base_train.cpu())
print(dataset_semantic.classwise_report(y_base_train.cpu(), base_preds.cpu()))

CLASSIFICATION FOR CLASS:pine
              precision    recall  f1-score   support

         0.0       1.00      1.00      1.00      6972
         1.0       1.00      1.00      1.00        84

    accuracy                           1.00      7056
   macro avg       1.00      1.00      1.00      7056
weighted avg       1.00      1.00      1.00      7056


CLASSIFICATION FOR CLASS:oak
              precision    recall  f1-score   support

         0.0       1.00      1.00      1.00      6972
         1.0       1.00      1.00      1.00        84

    accuracy                           1.00      7056
   macro avg       1.00      1.00      1.00      7056
weighted avg       1.00      1.00      1.00      7056


CLASSIFICATION FOR CLASS:maple
              precision    recall  f1-score   support

         0.0       1.00      1.00      1.00      6972
         1.0       1.00      1.00      1.00        84

    accuracy                           1.00      7056
   macro avg       1.00      1.00   

In [17]:
# Predict with the interchange intervention applied (base batch, list of source
# batches, intervention ids, coordinates) and score against the IIT labels for the
# five classes listed.
IIT_preds = LIM_trainer.iit_predict_logits(X_base_train.cpu(),
                                    [X_sources_train[0].cpu()],
                                    interventions.cpu(),
                                    id_to_coords)
print(dataset_semantic.classwise_report(y_IIT_train.cpu(), IIT_preds.cpu(), classes=["animals","plants","skin","move","roots"]))

CLASSIFICATION FOR CLASS:animals
CLASSIFICATION FOR CLASS:plants
CLASSIFICATION FOR CLASS:skin
CLASSIFICATION FOR CLASS:move
CLASSIFICATION FOR CLASS:roots

              precision    recall  f1-score   support

         0.0       1.00      0.98      0.99      6059
         1.0       0.91      1.00      0.95       997

    accuracy                           0.99      7056
   macro avg       0.96      0.99      0.97      7056
weighted avg       0.99      0.99      0.99      7056



              precision    recall  f1-score   support

         0.0       1.00      0.97      0.99      6556
         1.0       0.74      1.00      0.85       500

    accuracy                           0.98      7056
   macro avg       0.87      0.99      0.92      7056
weighted avg       0.98      0.98      0.98      7056



              precision    recall  f1-score   support

         0.0       1.00      0.96      0.98      6235
         1.0       0.75      1.00      0.86       821

    accuracy         

In [18]:
# Sweep intervention size (fraction of `embedding_dim`, starting at unit 0) x layer.
# Each configuration is fine-tuned from the same saved base weights (PATH) and then
# scored on the IIT labels.
# NOTE(review): `end` is a fraction of `embedding_dim`, but the model's layers are
# `embedding_dim*2` wide, so int_size=1.0 presumably covers only half a layer -- confirm.
# NOTE(review): only model weights are reset; any optimizer state the trainer keeps
# across `fit` calls (warm_start=True) is not -- confirm the runs are independent.
for int_size, layer in itertools.product([0.0625, 0.125, 0.25, 0.5, 1.0], [1, 0]):
    id_to_coords = {
        0: [{"layer": layer, "start": 0, "end": int(int_size*embedding_dim)}]
    }
    # Start every configuration from the pretrained weights.
    LIM_trainer.model.load_state_dict(torch.load(PATH))
    _ = LIM_trainer.fit(
        X_base_train,
        y_base_train,
        iit_data=iit_data,
        intervention_ids_to_coords=id_to_coords)
    IIT_preds = LIM_trainer.iit_predict_logits(
        X_base_train.cpu(),
        [X_sources_train[0].cpu()],
        interventions.cpu(),
        id_to_coords)
    print(f"\n\nLayer:{layer} {int_size}")
    print(dataset_semantic.classwise_report(y_IIT_train.cpu(), IIT_preds.cpu(), classes=["animals","plants","skin","move","roots"]))

Finished epoch 15 of 15; error is 2648.3134003281593



Layer:1 0.0625
CLASSIFICATION FOR CLASS:animals
CLASSIFICATION FOR CLASS:plants
CLASSIFICATION FOR CLASS:skin
CLASSIFICATION FOR CLASS:move
CLASSIFICATION FOR CLASS:roots

              precision    recall  f1-score   support

         0.0       1.00      0.97      0.99      6126
         1.0       0.85      1.00      0.92       930

    accuracy                           0.98      7056
   macro avg       0.93      0.99      0.95      7056
weighted avg       0.98      0.98      0.98      7056



              precision    recall  f1-score   support

         0.0       1.00      0.98      0.99      6507
         1.0       0.82      1.00      0.90       549

    accuracy                           0.98      7056
   macro avg       0.91      0.99      0.94      7056
weighted avg       0.99      0.98      0.98      7056



              precision    recall  f1-score   support

         0.0       1.00      0.97      0.98      6168
         1.0       0.81      1.00      0.90       888

    

Finished epoch 15 of 15; error is 3046.5513132810593



Layer:0 0.0625
CLASSIFICATION FOR CLASS:animals
CLASSIFICATION FOR CLASS:plants
CLASSIFICATION FOR CLASS:skin
CLASSIFICATION FOR CLASS:move
CLASSIFICATION FOR CLASS:roots

              precision    recall  f1-score   support

         0.0       1.00      0.88      0.94      6748
         1.0       0.26      0.93      0.41       308

    accuracy                           0.88      7056
   macro avg       0.63      0.91      0.67      7056
weighted avg       0.96      0.88      0.91      7056



              precision    recall  f1-score   support

         0.0       0.99      0.93      0.96      6849
         1.0       0.25      0.83      0.39       207

    accuracy                           0.92      7056
   macro avg       0.62      0.88      0.67      7056
weighted avg       0.97      0.92      0.94      7056



              precision    recall  f1-score   support

         0.0       0.98      0.87      0.92      6687
         1.0       0.21      0.63      0.32       369

    

Finished epoch 15 of 15; error is 2606.7934430837633



Layer:1 0.125
CLASSIFICATION FOR CLASS:animals
CLASSIFICATION FOR CLASS:plants
CLASSIFICATION FOR CLASS:skin
CLASSIFICATION FOR CLASS:move
CLASSIFICATION FOR CLASS:roots

              precision    recall  f1-score   support

         0.0       1.00      0.99      0.99      6052
         1.0       0.92      1.00      0.96      1004

    accuracy                           0.99      7056
   macro avg       0.96      0.99      0.98      7056
weighted avg       0.99      0.99      0.99      7056



              precision    recall  f1-score   support

         0.0       1.00      0.98      0.99      6518
         1.0       0.80      1.00      0.89       538

    accuracy                           0.98      7056
   macro avg       0.90      0.99      0.94      7056
weighted avg       0.98      0.98      0.98      7056



              precision    recall  f1-score   support

         0.0       1.00      0.97      0.98      6171
         1.0       0.81      1.00      0.90       885

    a

Finished epoch 15 of 15; error is 3036.8646034002304



Layer:0 0.125
CLASSIFICATION FOR CLASS:animals
CLASSIFICATION FOR CLASS:plants
CLASSIFICATION FOR CLASS:skin
CLASSIFICATION FOR CLASS:move
CLASSIFICATION FOR CLASS:roots

              precision    recall  f1-score   support

         0.0       0.99      0.87      0.93      6757
         1.0       0.22      0.79      0.34       299

    accuracy                           0.87      7056
   macro avg       0.60      0.83      0.63      7056
weighted avg       0.96      0.87      0.90      7056



              precision    recall  f1-score   support

         0.0       0.99      0.93      0.96      6860
         1.0       0.24      0.82      0.37       196

    accuracy                           0.92      7056
   macro avg       0.62      0.87      0.66      7056
weighted avg       0.97      0.92      0.94      7056



              precision    recall  f1-score   support

         0.0       0.94      0.86      0.90      6521
         1.0       0.17      0.36      0.23       535

    a

Finished epoch 15 of 15; error is 2643.0272701978683



Layer:1 0.25
CLASSIFICATION FOR CLASS:animals
CLASSIFICATION FOR CLASS:plants
CLASSIFICATION FOR CLASS:skin
CLASSIFICATION FOR CLASS:move
CLASSIFICATION FOR CLASS:roots

              precision    recall  f1-score   support

         0.0       1.00      0.97      0.98      6176
         1.0       0.81      1.00      0.89       880

    accuracy                           0.97      7056
   macro avg       0.90      0.98      0.94      7056
weighted avg       0.98      0.97      0.97      7056



              precision    recall  f1-score   support

         0.0       1.00      0.98      0.99      6534
         1.0       0.78      1.00      0.87       522

    accuracy                           0.98      7056
   macro avg       0.89      0.99      0.93      7056
weighted avg       0.98      0.98      0.98      7056



              precision    recall  f1-score   support

         0.0       1.00      0.94      0.97      6334
         1.0       0.66      1.00      0.80       722

    ac

Finished epoch 15 of 15; error is 3032.7772343158724



Layer:0 0.25
CLASSIFICATION FOR CLASS:animals
CLASSIFICATION FOR CLASS:plants
CLASSIFICATION FOR CLASS:skin
CLASSIFICATION FOR CLASS:move
CLASSIFICATION FOR CLASS:roots

              precision    recall  f1-score   support

         0.0       1.00      0.88      0.94      6740
         1.0       0.28      0.98      0.44       316

    accuracy                           0.89      7056
   macro avg       0.64      0.93      0.69      7056
weighted avg       0.97      0.89      0.92      7056



              precision    recall  f1-score   support

         0.0       0.99      0.93      0.96      6828
         1.0       0.28      0.82      0.41       228

    accuracy                           0.93      7056
   macro avg       0.64      0.87      0.69      7056
weighted avg       0.97      0.93      0.94      7056



              precision    recall  f1-score   support

         0.0       0.98      0.88      0.93      6631
         1.0       0.26      0.68      0.38       425

    ac

Finished epoch 15 of 15; error is 2673.2592685222626



Layer:1 0.5
CLASSIFICATION FOR CLASS:animals
CLASSIFICATION FOR CLASS:plants
CLASSIFICATION FOR CLASS:skin
CLASSIFICATION FOR CLASS:move
CLASSIFICATION FOR CLASS:roots

              precision    recall  f1-score   support

         0.0       1.00      0.96      0.98      6184
         1.0       0.80      1.00      0.89       872

    accuracy                           0.97      7056
   macro avg       0.90      0.98      0.93      7056
weighted avg       0.98      0.97      0.97      7056



              precision    recall  f1-score   support

         0.0       1.00      0.97      0.99      6555
         1.0       0.75      1.00      0.85       501

    accuracy                           0.98      7056
   macro avg       0.87      0.99      0.92      7056
weighted avg       0.98      0.98      0.98      7056



              precision    recall  f1-score   support

         0.0       1.00      0.95      0.98      6263
         1.0       0.73      1.00      0.84       793

    acc

Finished epoch 15 of 15; error is 3067.0811516046524



Layer:0 0.5
CLASSIFICATION FOR CLASS:animals
CLASSIFICATION FOR CLASS:plants
CLASSIFICATION FOR CLASS:skin
CLASSIFICATION FOR CLASS:move
CLASSIFICATION FOR CLASS:roots

              precision    recall  f1-score   support

         0.0       1.00      0.89      0.94      6695
         1.0       0.32      0.96      0.48       361

    accuracy                           0.89      7056
   macro avg       0.66      0.92      0.71      7056
weighted avg       0.96      0.89      0.92      7056



              precision    recall  f1-score   support

         0.0       0.99      0.93      0.96      6769
         1.0       0.34      0.80      0.48       287

    accuracy                           0.93      7056
   macro avg       0.67      0.87      0.72      7056
weighted avg       0.96      0.93      0.94      7056



              precision    recall  f1-score   support

         0.0       0.99      0.89      0.94      6642
         1.0       0.33      0.86      0.47       414

    acc

Finished epoch 15 of 15; error is 2769.2073878645897



Layer:1 1.0
CLASSIFICATION FOR CLASS:animals
CLASSIFICATION FOR CLASS:plants
CLASSIFICATION FOR CLASS:skin
CLASSIFICATION FOR CLASS:move
CLASSIFICATION FOR CLASS:roots

              precision    recall  f1-score   support

         0.0       1.00      0.93      0.97      6384
         1.0       0.62      1.00      0.76       672

    accuracy                           0.94      7056
   macro avg       0.81      0.97      0.86      7056
weighted avg       0.96      0.94      0.95      7056



              precision    recall  f1-score   support

         0.0       1.00      0.98      0.99      6537
         1.0       0.77      1.00      0.87       519

    accuracy                           0.98      7056
   macro avg       0.89      0.99      0.93      7056
weighted avg       0.98      0.98      0.98      7056



              precision    recall  f1-score   support

         0.0       1.00      0.92      0.96      6496
         1.0       0.51      1.00      0.68       560

    acc

Finished epoch 15 of 15; error is 3098.4665486812594



Layer:0 1.0
CLASSIFICATION FOR CLASS:animals
CLASSIFICATION FOR CLASS:plants
CLASSIFICATION FOR CLASS:skin
CLASSIFICATION FOR CLASS:move
CLASSIFICATION FOR CLASS:roots

              precision    recall  f1-score   support

         0.0       1.00      0.89      0.94      6675
         1.0       0.34      0.96      0.50       381

    accuracy                           0.89      7056
   macro avg       0.67      0.93      0.72      7056
weighted avg       0.96      0.89      0.92      7056



              precision    recall  f1-score   support

         0.0       1.00      0.95      0.98      6689
         1.0       0.53      0.98      0.69       367

    accuracy                           0.95      7056
   macro avg       0.77      0.97      0.83      7056
weighted avg       0.97      0.95      0.96      7056



              precision    recall  f1-score   support

         0.0       0.99      0.89      0.94      6594
         1.0       0.35      0.83      0.50       462

    acc

## 