In [None]:
import numpy as np
import torch
from preprocess.dataset import get_MNIST,get_dataset,get_handler
from models.model import Model
from al_methods.least_confidence import LeastConfidence
from ssl_methods.semi_fixmatch import fixmatch
import models
from torchvision import transforms
from framework.framework1 import Framework1
import framework
import torch.nn as nn
import time
from torch.utils.data import DataLoader
from scipy.spatial.distance import jensenshannon
import al_methods
import os 

In [None]:
def _dataset_entry(n_class, channels, size, mean, std, augmentations,
                   loader_tr_args, loader_te_args):
    """Build one `args_pool` entry.

    The training pipeline is `augmentations` followed by ToTensor + Normalize;
    the test pipeline is ToTensor + Normalize only. `mean` / `std` are used for
    both Normalize transforms and are also exposed under the 'normalize' key.
    """
    return {
        'n_class': n_class,
        'channels': channels,
        'size': size,
        'transform_tr': transforms.Compose(
            [*augmentations,
             transforms.ToTensor(),
             transforms.Normalize(mean, std)]),
        'transform_te': transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize(mean, std)]),
        'loader_tr_args': loader_tr_args,
        'loader_te_args': loader_te_args,
        'normalize': {'mean': mean, 'std': std},
    }


# Per-dataset settings, keyed by lower-case dataset name.
args_pool = {
    'mnist': _dataset_entry(
        n_class=10, channels=1, size=28,
        mean=(0.1307,), std=(0.3081,),
        augmentations=[transforms.RandomHorizontalFlip()],
        loader_tr_args={'batch_size': 128, 'num_workers': 8},
        loader_te_args={'batch_size': 1024, 'num_workers': 8},
    ),
    'svhn': _dataset_entry(
        n_class=10, channels=3, size=32,
        mean=(0.4377, 0.4438, 0.4728), std=(0.1980, 0.2010, 0.1970),
        augmentations=[transforms.RandomCrop(size=32, padding=4),
                       transforms.RandomHorizontalFlip()],
        loader_tr_args={'batch_size': 128, 'num_workers': 8},
        loader_te_args={'batch_size': 1024, 'num_workers': 8},
    ),
    'cifar10': _dataset_entry(
        n_class=10, channels=3, size=32,
        mean=(0.4914, 0.4822, 0.4465), std=(0.2470, 0.2435, 0.2616),
        augmentations=[transforms.RandomCrop(size=32, padding=4),
                       transforms.RandomHorizontalFlip()],
        loader_tr_args={'batch_size': 256, 'num_workers': 8},
        loader_te_args={'batch_size': 512, 'num_workers': 8},
    ),
    'cifar100': _dataset_entry(
        n_class=100, channels=3, size=32,
        mean=(0.5071, 0.4867, 0.4408), std=(0.2675, 0.2565, 0.2761),
        augmentations=[transforms.RandomCrop(size=32, padding=4),
                       transforms.RandomHorizontalFlip()],
        loader_tr_args={'batch_size': 2048, 'num_workers': 4},
        loader_te_args={'batch_size': 512, 'num_workers': 8},
    ),
}

In [None]:
# Load the train/test inputs and labels via the project's get_dataset helper,
# using "./datasets" as the data directory argument.
# NOTE(review): the name is spelled "Mnist" here but "mnist" in get_handler and
# args_pool -- confirm get_dataset accepts this spelling.
X_tr, Y_tr, X_te, Y_te = get_dataset("Mnist", "./datasets")

In [None]:
# Bring the loaded data into the container types used downstream:
# inputs as numpy arrays, and (when the loader returned lists) labels as tensors.
if type(X_tr) is list:
    X_tr, X_te = np.array(X_tr), np.array(X_te)
    Y_tr, Y_te = (torch.tensor(np.array(labels)) for labels in (Y_tr, Y_te))

# Inputs whose items are not numpy arrays are converted through their .numpy() method.
if not (type(X_tr[0]) is np.ndarray):
    X_tr, X_te = X_tr.numpy(), X_te.numpy()
    

In [None]:
# Pool / test sizes and the dataset handler for MNIST.
n_pool = len(Y_tr)
n_test = len(Y_te)
handler = get_handler("mnist")

# Labelling budget. All three values are PERCENTAGES of the training pool
# (they are multiplied by n_pool/100 below), not absolute sample counts.
nEnd = 50    # percentage of the pool labelled once all rounds are done
nQuery = 1   # percentage of the pool queried per round
nStart = 1   # percentage of the pool labelled before the first round

NUM_INIT_LB = int(nStart*n_pool/100)
NUM_QUERY = int(nQuery*n_pool/100) if nStart != 100 else 0

# Number of rounds = ceil(samples still to label / samples per round).
# Guard on NUM_QUERY itself: it is also 0 when the pool has fewer than 100
# samples, in which case the division below would raise ZeroDivisionError.
if NUM_QUERY > 0:
    to_label = int(nEnd*n_pool/100) - NUM_INIT_LB
    NUM_ROUND = int(to_label / NUM_QUERY)
    if to_label % NUM_QUERY != 0:
        NUM_ROUND += 1
else:
    NUM_ROUND = 0
print(NUM_INIT_LB,NUM_ROUND,NUM_QUERY)


## ResNet-18 / ResNet-50

In [None]:
# Look the constructor up by name in the `models` package namespace, then
# instantiate it with the MNIST class count.
resnet_ctor = vars(models)["ResNet50"]
model = resnet_ctor(n_class=args_pool['mnist']['n_class'])

In [None]:
# Swap the first convolution and the last discriminator layer so that the
# input channel count and the output size follow the MNIST settings.
mnist_cfg = args_pool['mnist']
model.feature_extractor.conv1 = nn.Conv2d(
    in_channels=mnist_cfg['channels'], out_channels=16,
    kernel_size=3, stride=1, padding=1, bias=False,
)
model.discriminator.dis_fc2 = nn.Linear(
    in_features=50, out_features=mnist_cfg['n_class'], bias=True,
)

In [None]:
# Display the model after the layer replacement above.
model

## MobileNet

In [None]:
# Look the constructor up by name in the `models` package namespace and
# instantiate it with its default arguments.
mobilenet_ctor = vars(models)["MobileNet"]
model = mobilenet_ctor()

In [None]:
# Display the freshly constructed model (before any layer replacement).
model

In [None]:
# Swap the first convolution and the final linear layer so that the input
# channel count and the output size follow the MNIST settings.
mnist_cfg = args_pool['mnist']
model.conv1 = nn.Conv2d(
    in_channels=mnist_cfg['channels'], out_channels=32,
    kernel_size=3, stride=1, padding=1, bias=False,
)
model.linear = nn.Linear(
    in_features=1024, out_features=mnist_cfg['n_class'], bias=True,
)

In [None]:
# Display the model after the layer replacement above.
model

## VGG

In [None]:
# Look the constructor up by name in the `models` package namespace and
# instantiate the 'VGG16' variant.
vgg_ctor = vars(models)["VGG"]
model = vgg_ctor(vgg_name='VGG16')

In [None]:
# Display the freshly constructed model (before any layer replacement).
model

In [None]:
# Swap the first feature layer and the classifier so that the input channel
# count and the output size follow the MNIST settings.
mnist_cfg = args_pool['mnist']
model.features[0] = nn.Conv2d(
    in_channels=mnist_cfg['channels'], out_channels=64,
    kernel_size=3, stride=1, padding=1, bias=False,
)
model.classifier = nn.Linear(
    in_features=512, out_features=mnist_cfg['n_class'], bias=True,
)

In [None]:
# Display the model after the layer replacement above.
model

In [None]:

# model=Model('resnet50').get_model()
# model


In [None]:
# Replace the top-level `conv1` / `fc` layers with MNIST-sized ones.
# NOTE(review): this presumably targets the model from the commented-out
# `Model('resnet50').get_model()` cell -- confirm. When the notebook is run top
# to bottom, `model` is whatever the previous section built; assigning
# `conv1` / `fc` on a model that has no such layers would only attach two
# unused layers to it, so the replacement is skipped in that case.
if hasattr(model, 'conv1') and hasattr(model, 'fc'):
    # First convolution: input channels taken from the MNIST settings.
    model.conv1 = torch.nn.Conv2d(args_pool['mnist']['channels'], 64, kernel_size=7, stride=2, padding=3, bias=False)
    # Classification layer: one output per MNIST class.
    model.fc = torch.nn.Linear(2048, args_pool['mnist']['n_class'])
else:
    print(f"Skipping conv1/fc replacement: {type(model).__name__} has no top-level 'conv1' and 'fc' layers.")
model

In [None]:
# in the main file
idxs_lb = np.zeros(n_pool, dtype=bool)
idxs_lb
# in the main file 
idxs_tmp = np.arange(n_pool)
idxs_tmp
np.random.shuffle(idxs_tmp)
idxs_tmp
# in the main file
idxs_lb[idxs_tmp[:NUM_INIT_LB]] = True
idxs_lb

In [None]:
class Args:
    """Plain attribute container with dataset settings and experiment options.

    The eight positional parameters are stored unchanged under the attribute of
    the same name. The keyword-only parameters were previously hard-coded; their
    defaults reproduce those values exactly, so existing positional calls keep
    working while other datasets or strategies can now be configured without
    editing the class.

    Keyword-only parameters (all stored as attributes of the same name):
        dataset, save_path, model, lr, schedule, momentum, gammas, framework,
        optimizer, save_model, ALstrat, SSLstrat.

    `schedule` and `gammas` default to fresh lists `[20, 40]` and `[0.1, 0.1]`
    per instance; a passed-in sequence is copied into a new list.
    """

    def __init__(self, n_class, img_size, channels, transform_tr, transform_te,
                 loader_tr_args, loader_te_args, normalize, *,
                 dataset='mnist', save_path='./save', model='ResNet50',
                 lr=0.1, schedule=None, momentum=0.9, gammas=None,
                 framework='framwork1', optimizer='SGD', save_model=False,
                 ALstrat='LeastConfidence', SSLstrat='fixmatch'):
        # Dataset-specific settings.
        self.n_class = n_class
        self.img_size = img_size
        self.channels = channels
        self.transform_tr = transform_tr
        self.transform_te = transform_te
        self.loader_tr_args = loader_tr_args
        self.loader_te_args = loader_te_args
        self.normalize = normalize

        # Experiment options (defaults equal the former hard-coded values).
        self.dataset = dataset
        self.save_path = save_path
        self.model = model
        self.lr = lr
        # None sentinel + copy: never share one list between instances.
        self.schedule = [20, 40] if schedule is None else list(schedule)
        self.momentum = momentum
        self.gammas = [0.1, 0.1] if gammas is None else list(gammas)
        # The default spelling 'framwork1' is kept on purpose: it is part of
        # the result file name built from this attribute.
        self.framework = framework
        self.optimizer = optimizer
        self.save_model = save_model
        self.ALstrat = ALstrat
        self.SSLstrat = SSLstrat

In [None]:
# Pull the MNIST settings out of args_pool into individual variables
# (note that the 'size' entry is bound to `img_size`).
dataset_args = args_pool["mnist"]
(n_class, img_size, channels,
 transform_tr, transform_te,
 loader_tr_args, loader_te_args,
 normalize) = (
    dataset_args[key]
    for key in ('n_class', 'size', 'channels',
                'transform_tr', 'transform_te',
                'loader_tr_args', 'loader_te_args',
                'normalize')
)

In [None]:
# Bundle the dataset settings (plus the defaults defined in Args) into one object.
args = Args(
    n_class=n_class,
    img_size=img_size,
    channels=channels,
    transform_tr=transform_tr,
    transform_te=transform_te,
    loader_tr_args=loader_tr_args,
    loader_te_args=loader_te_args,
    normalize=normalize,
)

In [None]:
# Create the framework object from the data arrays, the initial labelled mask,
# the current `model`, the dataset handler and the settings object.
# NOTE(review): `model` is whichever model cell was executed last, while
# args.model is fixed to 'ResNet50' -- confirm the two are meant to agree.
framework_1= Framework1(X_tr, Y_tr, X_te, Y_te, idxs_lb, model, handler, args)

In [None]:
# Alternate active-learning (even rounds) and semi-supervised (odd rounds)
# query/train steps, recording one accuracy value per round.
print(f'Strategy for active learning: {args.ALstrat}; strategy for semi-supervised learning: {args.SSLstrat}')
stratAl=LeastConfidence(framework_1.X_tr, framework_1.Y_tr, framework_1.X_te, framework_1.Y_te, framework_1.idxs_lb, framework_1.net, framework_1.handler, framework_1.args,framework_1.n_pool,framework_1.device)
stratSSL=fixmatch(framework_1.X_tr, framework_1.Y_tr, framework_1.X_te, framework_1.Y_te, framework_1.idxs_lb, framework_1.net, framework_1.handler, framework_1.args,framework_1.n_pool,framework_1.device,framework_1.predict,framework_1.g)

# Initial training on the starting labelled set, then the baseline accuracy.
framework_1.train(alpha=2e-3, n_epoch=10)
test_acc = framework_1.predict(framework_1.X_te, framework_1.Y_te)

# acc[0] is the baseline; acc[rd + 1] is the accuracy obtained in round rd.
acc = np.zeros(NUM_ROUND + 1)
acc[0] = test_acc

# Total number of samples that may be labelled by the end of all rounds.
label_budget = int(nEnd * framework_1.n_pool / 100)

for rd in range(NUM_ROUND):
    print('Round {}/{}'.format(rd, NUM_ROUND), flush=True)

    if rd % 2 == 0:
        # Active-learning round: fall back to the framework's own training
        # when the AL strategy does not provide a `train` method.
        strategy = stratAl
        trainer = stratAl if hasattr(stratAl, 'train') else framework_1
    else:
        # Semi-supervised round.
        strategy = stratSSL
        trainer = stratSSL

    # Never query more than what is left of the budget. A per-round local is
    # used so the global NUM_QUERY is not shrunk and the cell can be re-run.
    labeled = len(np.arange(framework_1.n_pool)[framework_1.idxs_lb])
    n_query = min(NUM_QUERY, label_budget - labeled)

    # query
    ts = time.time()
    q_idxs = strategy.query(n_query)
    framework_1.idxs_lb[q_idxs] = True
    tp = time.time() - ts  # time spent on the query step

    # update
    framework_1.update(framework_1.idxs_lb)
    best_test_acc = trainer.train(alpha=2e-3, n_epoch=10)
    t_iter = time.time() - ts  # time spent on the whole round

    # round accuracy: offset by one so the baseline in acc[0] is preserved
    # and the last slot acc[NUM_ROUND] gets filled.
    acc[rd + 1] = best_test_acc



In [None]:
# Display the recorded accuracies.
acc

In [None]:
# Save the accuracy array under results/<framework>(<ALstrat><SSLstrat>).npy.
folder_result_acc='results'
# out_file = os.path.join(args.save_path, args.save_file)
folder_is_new = not os.path.exists(folder_result_acc)
# makedirs(exist_ok=True) does not fail when the folder already exists,
# unlike a separate existence check followed by os.mkdir.
os.makedirs(folder_result_acc, exist_ok=True)
if folder_is_new:
    print(f"Folder '{folder_result_acc}' created succesfuly.")
file_path=os.path.join(folder_result_acc,args.framework+"("+args.ALstrat+args.SSLstrat+")")
# np.save appends ".npy" because file_path does not already end with it.
np.save(file_path,acc)

In [None]:
# Read back the array written by the save cell. The path is derived from
# `file_path` (np.save added the ".npy" suffix) instead of being hard-coded,
# so it stays correct when the framework or strategy names in `args` change.
np.load(file_path + ".npy")