# Imports

In [None]:
import os
import torch
from torch import nn
import torch.nn.functional as F
from collections import OrderedDict
from PIL import Image
import argparse

from pl_bolts.models.self_supervised import SimCLR
from pl_bolts.models.self_supervised.simclr import SimCLREvalDataTransform, SimCLRTrainDataTransform

In [None]:
from main import load_data
from main import build_dt
import pickle

# Load Model

## SimCLR

In [None]:
class Flatten(nn.Module):
    """Layer that keeps only the first element of whatever it receives.

    Despite its name it does not reshape anything: it simply indexes its
    input with ``[0]``.
    NOTE(review): presumably needed because the preceding encoder returns a
    sequence of tensors rather than a single tensor - confirm.
    """

    def forward(self, input):
        """Return ``input[0]``."""
        first_item = input[0]
        return first_item

# Checkpoint used to instantiate the SimCLR model.
weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt'
simclr = SimCLR.load_from_checkpoint(weight_path, strict=False)

# Classifier = SimCLR encoder -> first encoder output -> 2-way linear head.
model = nn.Sequential(simclr.encoder, Flatten(), nn.Linear(2048, 2))

# Overwrite the weights with the locally stored state dict (loaded on CPU).
finetuned_state = torch.load('resnet50_simclr_crop_12', map_location='cpu')
model.load_state_dict(finetuned_state)

# Switch to inference mode; as the last expression of the cell this also
# displays the module.
model.eval()

## SwAV

In [None]:
import torch
from pl_bolts.models.self_supervised import SwAV

# Checkpoint used to instantiate the SwAV model.
weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/swav_imagenet/swav_imagenet.pth.tar'
swav = SwAV.load_from_checkpoint(weight_path, strict=True)

# Keep only the wrapped network and replace its `prototypes` layer with a
# 2-output linear layer.
model = swav.model
model.prototypes = nn.Linear(128, 2)

# Overwrite the weights with the locally stored state dict (loaded on CPU).
finetuned_state = torch.load('resnet50_swav_crop_10', map_location='cpu')
model.load_state_dict(finetuned_state)

# Switch to inference mode; as the last expression of the cell this also
# displays the module.
model.eval()

## Model accuracy test

In [None]:
import utils
import matplotlib.pyplot as plt

from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader


BATCH = 16

# Images are first resized to `image_size`, then centre-cropped to 224x224.
# Alternative setting: image_size = [224, 224]
image_size = [373, 373]


def _build_transform():
    """Return a fresh resize / centre-crop / to-tensor / normalise pipeline.

    NOTE(review): the normalisation constants look like the usual ImageNet
    mean and std - confirm.
    """
    steps = [
        transforms.Resize(image_size),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(steps)


# Training split.
train_data = datasets.ImageFolder(
    root="../../data/chest_xray/train",
    transform=_build_transform(),
)
train_data_loader = DataLoader(train_data, batch_size=BATCH, num_workers=4)

# Test split, preprocessed with the same pipeline.
test_data = datasets.ImageFolder(
    root="../../data/chest_xray/test",
    transform=_build_transform(),
)
test_data_loader = DataLoader(test_data, batch_size=BATCH, num_workers=4)

#images, labels = next(iter(test_data_loader))
#plt.imshow(images[0][0])

def test_loop(dataloader, model):
    """Compute the classification accuracy of ``model`` over ``dataloader``.

    The model may return either an ``(embedding, prediction)`` pair or the
    prediction tensor alone (for example an ``nn.Sequential`` ending in a
    linear layer). Unpacking a bare tensor as a pair would either raise or,
    for a batch of exactly two rows, silently split the batch, so the two
    cases are told apart explicitly.

    The dataset size, the number of batches and the running accuracy
    (correct predictions so far divided by the *total* dataset size) are
    printed along the way.

    Args:
        dataloader: iterable of ``(X, y)`` batches that exposes ``.dataset``
            and supports ``len()``.
        model: callable applied to each ``X`` batch.

    Returns:
        Accuracy as a percentage in ``[0, 100]``; ``0.0`` when the dataset
        is empty.
    """
    size = len(dataloader.dataset)
    print("Taille ", size)
    num_batches = len(dataloader)
    print("Batchs ", num_batches)

    if size == 0:
        # Nothing to evaluate; avoid dividing by zero below.
        return 0.0

    correct = 0
    with torch.no_grad():
        for X, y in dataloader:
            output = model(X)
            if isinstance(output, (tuple, list)):
                # (embedding, prediction) pair: only the prediction is scored.
                _, pred = output
            else:
                pred = output

            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            print(100 * correct / size)
    correct /= size
    return 100 * correct


# The name starts with `test_`; opt out of pytest collection so that importing
# this helper into a test module does not make pytest treat it as a test case.
test_loop.__test__ = False
# Evaluate the currently loaded `model` on the test split; the returned
# percentage is the value of this last expression.
test_loop(test_data_loader, model)

# Build and Train DT model

In [None]:
# Image folders the loaders are built from.
TRAIN_DATA_PATH = "../../data/chest_xray/train"
TEST_DATA_PATH = "../../data/chest_xray/test"

# Options describing the model being explained.
CROP = True  # True for models based on the cropped-images dataset
EMB_OUTPUT = False  # True for models that output the embedding in addition to the prediction

# Pickle files used to cache the loaders and the trained decision tree.
DATA_LOADER_PATH = "loaders/SimCLR_crop/dataloader.p"
TEST_LOADER_PATH = "loaders/SimCLR_crop/testloader.p"
MODEL_PATH = "dtmodels/SimCLR_crop/bestmodel2"

### Training data

In [None]:
# Reuse the cached training loader when it exists; otherwise build it with
# `load_data` and cache it for the next run.
if os.path.isfile(DATA_LOADER_PATH):
    # NOTE: unpickling can execute arbitrary code - only load cache files
    # that were produced locally by this notebook.
    with open(DATA_LOADER_PATH, "rb") as cache_file:
        train_loader = pickle.load(cache_file)
else:
    # Create the cache directory up front so that a missing folder fails fast
    # instead of after the loader has been built.
    cache_dir = os.path.dirname(DATA_LOADER_PATH)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)
    train_loader = load_data(model, TRAIN_DATA_PATH, crop=CROP, emb_output=EMB_OUTPUT)
    # `with` guarantees the file is flushed and closed once the dump is done.
    with open(DATA_LOADER_PATH, "wb") as cache_file:
        pickle.dump(train_loader, cache_file)

### Test data

In [None]:
# Reuse the cached test loader when it exists; otherwise build it with
# `load_data` and cache it for the next run.
if os.path.isfile(TEST_LOADER_PATH):
    # NOTE: unpickling can execute arbitrary code - only load cache files
    # that were produced locally by this notebook.
    with open(TEST_LOADER_PATH, "rb") as cache_file:
        test_loader = pickle.load(cache_file)
else:
    # Create the cache directory up front so that a missing folder fails fast
    # instead of after the loader has been built.
    cache_dir = os.path.dirname(TEST_LOADER_PATH)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)
    test_loader = load_data(model, TEST_DATA_PATH, crop=CROP, emb_output=EMB_OUTPUT)
    # `with` guarantees the file is flushed and closed once the dump is done.
    with open(TEST_LOADER_PATH, "wb") as cache_file:
        pickle.dump(test_loader, cache_file)

### Model parameters

In [None]:
# Training settings for the decision-tree model.
# The parser is used only as a container for defaults: it is parsed with an
# empty argument list, so every option takes the default declared here.
# Help texts state the actual defaults.
parser = argparse.ArgumentParser(description='DT model')
parser.add_argument('--batch-size', type=int, default=50, metavar='N',
                    help='input batch size for training (default: 50)')
parser.add_argument('--input-dim', type=int, default=224 * 224 * 3, metavar='N',
                    help='input dimension size (default: 224 * 224 * 3)')
parser.add_argument('--output-dim', type=int, default=2, metavar='N',
                    help='output dimension size (default: 2)')
parser.add_argument('--max-depth', type=int, default=2, metavar='N',
                    help='maximum depth of tree (default: 2)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
                    help='number of epochs to train (default: 10)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                    help='learning rate (default: 0.01)')
parser.add_argument('--lmbda', type=float, default=0.1, metavar='LMBDA',
                    help='temperature rate (default: 0.1)')
parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                    help='SGD momentum (default: 0.5)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=5, metavar='N',
                    help='how many batches to wait before logging training status')

# An explicit empty list keeps argparse from reading the notebook kernel's
# own command line.
args = parser.parse_args([])

# CUDA is forced off; the automatic detection and seeding are kept below for
# reference.
args.cuda = False
# args.cuda = not args.no_cuda and torch.cuda.is_available()
# torch.manual_seed(args.seed)
# if args.cuda:
#     torch.cuda.manual_seed(args.seed)


### Model training

In [None]:
# Reuse a previously trained tree when one is cached at MODEL_PATH; otherwise
# build it, train it for `args.epochs` epochs and cache the result.
if os.path.isfile(MODEL_PATH):
    # NOTE: unpickling can execute arbitrary code - only load model files
    # that were produced locally by this notebook.
    with open(MODEL_PATH, "rb") as model_file:
        model_tree = pickle.load(model_file)
else:
    # Create the output directory before training so that a missing folder
    # fails fast instead of after all epochs have run.
    model_dir = os.path.dirname(MODEL_PATH)
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)

    model_tree = build_dt(args)
    # Epochs are numbered from 1.
    for epoch in range(1, args.epochs + 1):
        model_tree.train_(train_loader, epoch, crop=CROP)

    # `with` guarantees the file is flushed and closed once the dump is done.
    with open(MODEL_PATH, "wb") as model_file:
        pickle.dump(model_tree, model_file)

# DT model fidelity

In [None]:
# Evaluate the decision tree through its `test_` method on `train_loader`.
# NOTE(review): `test_loader` is built earlier in this file but is not used
# here - confirm that evaluating on the training loader is intended.
model_tree.test_(train_loader, crop=CROP)

# Tree visualization

In [None]:
# Build the tree object from the decision-tree model; it is kept in `tree`
# so that it can be rendered afterwards.
tree = model_tree.buildTree()

In [None]:
# Render the tree under the name 'dt_viz' and open the result (view=True).
# NOTE(review): this looks like a graphviz-style `render` call - confirm the
# type returned by `buildTree()`.
tree.render('dt_viz', view=True)  