In [4]:
import numpy as np
import torch
from preprocess.dataset import get_MNIST,get_dataset,get_handler
from models.model import Model
from al_methods.least_confidence import LeastConfidence
from al_methods.entropy_sampling import EntropySampling
from al_methods.batch_BALD import BatchBALD
from al_methods.core_set import CoreSet
from ssl_methods.semi_fixmatch import fixmatch
from ssl_methods.semi_flexmatch import flexmatch
from ssl_methods.semi_pseudolabel import pseudolabel
from sklearn.metrics import cohen_kappa_score as am_capa
import models
from torchvision import transforms
from framework.framework2 import Framework2
import torch.nn as nn
import time
from torch.utils.data import DataLoader
from scipy.spatial.distance import jensenshannon
import al_methods
import os
import seaborn as sns 
import matplotlib
# matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
sns.set_theme()

### Data Configuration

In [5]:
def _make_dataset_args(n_class, channels, size, mean, std,
                       train_batch, train_workers, test_batch, test_workers,
                       random_crop=True):
    """Build one entry of ``args_pool``.

    Parameters
    ----------
    n_class, channels, size : int
        Stored under the 'n_class', 'channels' and 'size' keys.
    mean, std : tuple of float
        Per-channel statistics handed to ``transforms.Normalize`` and also
        stored verbatim under the 'normalize' key.
    train_batch, train_workers, test_batch, test_workers : int
        Values for the 'batch_size' / 'num_workers' keys of
        'loader_tr_args' and 'loader_te_args'.
    random_crop : bool
        When True the training pipeline starts with
        ``RandomCrop(size, padding=4)``.

    Returns
    -------
    dict
        Keys: 'n_class', 'channels', 'size', 'transform_tr', 'transform_te',
        'loader_tr_args', 'loader_te_args', 'normalize'.
    """
    train_steps = []
    if random_crop:
        train_steps.append(transforms.RandomCrop(size=size, padding=4))
    # NOTE(review): a horizontal flip is applied for every dataset, including
    # the digit datasets (mnist, svhn) — confirm this augmentation is intended.
    train_steps.extend([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    test_steps = [transforms.ToTensor(), transforms.Normalize(mean, std)]

    return {
        'n_class': n_class,
        'channels': channels,
        'size': size,
        'transform_tr': transforms.Compose(train_steps),
        'transform_te': transforms.Compose(test_steps),
        'loader_tr_args': {'batch_size': train_batch, 'num_workers': train_workers},
        'loader_te_args': {'batch_size': test_batch, 'num_workers': test_workers},
        'normalize': {'mean': mean, 'std': std},
    }


# Per-dataset configuration, keyed by lowercase dataset name.
args_pool = {
    'mnist': _make_dataset_args(
        n_class=10, channels=1, size=28,
        mean=(0.1307,), std=(0.3081,),
        train_batch=128, train_workers=8,
        test_batch=1024, test_workers=8,
        random_crop=False,  # the MNIST pipeline has no RandomCrop step
    ),
    'svhn': _make_dataset_args(
        n_class=10, channels=3, size=32,
        mean=(0.4377, 0.4438, 0.4728), std=(0.1980, 0.2010, 0.1970),
        train_batch=128, train_workers=8,
        test_batch=1024, test_workers=8,
    ),
    'cifar10': _make_dataset_args(
        n_class=10, channels=3, size=32,
        mean=(0.4914, 0.4822, 0.4465), std=(0.2470, 0.2435, 0.2616),
        train_batch=256, train_workers=8,
        test_batch=512, test_workers=8,
    ),
    'cifar100': _make_dataset_args(
        n_class=100, channels=3, size=32,
        mean=(0.5071, 0.4867, 0.4408), std=(0.2675, 0.2565, 0.2761),
        train_batch=2048, train_workers=4,
        test_batch=512, test_workers=8,
    ),
}

### Dataset

In [6]:
# Load the data through the project's dataset helper; the four returned
# values are bound to X_tr, Y_tr, X_te, Y_te (belongs in the main script).
# NOTE(review): the dataset name is spelled "Mnist" here but "mnist" in
# get_handler / args_pool — confirm get_dataset accepts this casing.
X_tr, Y_tr, X_te, Y_te = get_dataset("Mnist", "./datasets")

In [7]:
# Normalise the containers returned by get_dataset (belongs in the main script).
# When the features arrive as Python lists, features become numpy arrays and
# labels become torch tensors.
if isinstance(X_tr, list):
    X_tr = np.array(X_tr)
    Y_tr = torch.tensor(np.array(Y_tr))
    X_te = np.array(X_te)
    Y_te = torch.tensor(np.array(Y_te))

# Any feature container that is still not a numpy array is converted through
# its .numpy() method. Each container is checked on its own type: the previous
# version inspected only X_tr[0], which failed on an empty X_tr and never
# looked at X_te at all.
if not isinstance(X_tr, np.ndarray):
    X_tr = X_tr.numpy()
if not isinstance(X_te, np.ndarray):
    X_te = X_te.numpy()

In [8]:
# Pool sizes and dataset handler (belongs in the main script).
n_pool = len(Y_tr)
n_test = len(Y_te)
handler = get_handler("mnist")

# Labeling budget, expressed as percentages of the pool size
# (each value is multiplied by n_pool / 100 below).
nEnd = 100    # percentage of the pool labeled once all rounds are done
nQuery = 10   # percentage of the pool queried per round
nStart = 1    # percentage of the pool labeled before the first round

NUM_INIT_LB = int(nStart * n_pool / 100)
NUM_QUERY = int(nQuery * n_pool / 100) if nStart != 100 else 0

# Number of rounds = ceil(remaining budget / queries per round).
# Guarding on NUM_QUERY itself (not only on nStart) avoids a ZeroDivisionError
# when the pool is so small that the per-round query size truncates to 0.
# A negative remainder (nEnd < nStart) is clamped to 0 rounds.
remaining_budget = max(int(nEnd * n_pool / 100) - NUM_INIT_LB, 0)
if NUM_QUERY > 0:
    full_rounds, leftover = divmod(remaining_budget, NUM_QUERY)
    NUM_ROUND = full_rounds + (1 if leftover else 0)
else:
    NUM_ROUND = 0
print("begin with : ",NUM_INIT_LB," Number of round",NUM_ROUND," How many to query",NUM_QUERY)


begin with :  600  Number of round 10  How many to query 6000


In [9]:
# Build the boolean mask of initially labeled pool samples (belongs in the
# main script).
# The draw uses a dedicated seeded generator so that the initial labeled set
# is the same on every Restart & Run All, without touching numpy's global
# random state.
INIT_SEED = 0
init_rng = np.random.default_rng(INIT_SEED)

idxs_lb = np.zeros(n_pool, dtype=bool)

# Random ordering of all pool indices; the first NUM_INIT_LB are labeled.
idxs_tmp = init_rng.permutation(n_pool)
idxs_lb[idxs_tmp[:NUM_INIT_LB]] = True

# Last expression of the cell: displays the mask.
idxs_lb

array([False, False, False, ..., False, False, False])

### Args

In [10]:
class Args:
    """Plain attribute container for dataset settings and run hyperparameters.

    The eight positional parameters are stored unchanged as attributes of the
    same name. The remaining settings were previously hard-coded; they are now
    keyword-only parameters whose defaults reproduce the former values, so
    existing ``Args(...)`` calls behave exactly as before.

    Parameters
    ----------
    n_class, img_size, channels
        Dataset description values.
    transform_tr, transform_te
        Training / test transforms.
    loader_tr_args, loader_te_args
        Keyword dictionaries for the training / test data loaders.
    normalize
        Normalisation statistics.
    dataset, save_path, model, lr, schedule, momentum, gammas, framework,
    optimizer, save_model, ALstrat, SSLstrat, n_epoch
        Run settings, stored as attributes of the same name. ``schedule`` and
        ``gammas`` default to ``[20, 40]`` and ``[0.1, 0.1]``; a fresh list is
        created per instance so instances never share mutable state.
    """

    def __init__(self, n_class, img_size, channels, transform_tr, transform_te,
                 loader_tr_args, loader_te_args, normalize, *,
                 dataset='mnist', save_path='./save', model='ResNet50',
                 lr=0.1, schedule=None, momentum=0.9, gammas=None,
                 framework='framwork1', optimizer='SGD', save_model=False,
                 ALstrat='LeastConfidence', SSLstrat='fixmatch', n_epoch=10):
        # Dataset-dependent settings.
        self.n_class = n_class
        self.img_size = img_size
        self.channels = channels
        self.transform_tr = transform_tr
        self.transform_te = transform_te
        self.loader_tr_args = loader_tr_args
        self.loader_te_args = loader_te_args
        self.normalize = normalize

        # Run settings (defaults match the previously hard-coded values).
        self.dataset = dataset
        self.save_path = save_path
        self.model = model
        self.lr = lr
        # Copy into new lists so callers' sequences are never aliased.
        self.schedule = [20, 40] if schedule is None else list(schedule)
        self.momentum = momentum
        self.gammas = [0.1, 0.1] if gammas is None else list(gammas)
        # NOTE(review): 'framwork1' looks like a typo of 'framework1', but the
        # value is a runtime string that other code may match — left unchanged.
        self.framework = framework
        self.optimizer = optimizer
        self.save_model = save_model
        self.ALstrat = ALstrat
        self.SSLstrat = SSLstrat
        self.n_epoch = n_epoch

In [11]:
# Select the MNIST entry of args_pool and unpack its fields into the
# module-level names consumed when constructing Args.
dataset_args = args_pool["mnist"]
n_class, img_size, channels = (
    dataset_args[field] for field in ('n_class', 'size', 'channels')
)
transform_tr, transform_te = (
    dataset_args['transform_tr'],
    dataset_args['transform_te'],
)
loader_tr_args, loader_te_args = (
    dataset_args['loader_tr_args'],
    dataset_args['loader_te_args'],
)
normalize = dataset_args['normalize']


In [12]:
# Bundle the dataset settings into the Args container used by the framework.
args = Args(
    n_class=n_class,
    img_size=img_size,
    channels=channels,
    transform_tr=transform_tr,
    transform_te=transform_te,
    loader_tr_args=loader_tr_args,
    loader_te_args=loader_te_args,
    normalize=normalize,
)

## Models

In [13]:
# Look the architecture up by name in the project's `models` package and
# instantiate it with the MNIST class count.
# NOTE(review): the recorded run of this cell failed with
# "TypeError: MobileNet.__init__() got an unexpected keyword argument 'n_class'"
# — confirm the constructor's keyword (or the intended architecture) before
# relying on this cell; the following cell also assumes the model exposes
# `feature_extractor.conv1` and `discriminator.dis_fc2`.
model_factory = vars(models)["MobileNet"]
model = model_factory(n_class=args_pool['mnist']['n_class'])

TypeError: MobileNet.__init__() got an unexpected keyword argument 'n_class'

In [11]:
# Replace the input convolution and the final linear layer so that they use
# the MNIST channel count and class count from args_pool.
mnist_cfg = args_pool['mnist']
model.feature_extractor.conv1 = torch.nn.Conv2d(
    mnist_cfg['channels'],
    16,
    kernel_size=3,
    stride=1,
    padding=1,
    bias=False,
)
model.discriminator.dis_fc2 = torch.nn.Linear(
    in_features=50,
    out_features=mnist_cfg['n_class'],
    bias=True,
)

### Framework 2

In [12]:
# Wire data, initial labeled mask, model, handler and settings into Framework2.
framework_2 = Framework2(
    X_tr, Y_tr,
    X_te, Y_te,
    idxs_lb,
    model,
    handler,
    args,
)

#### Train the father model

In [13]:
# Instantiate the active-learning and semi-supervised strategies on the state
# held by framework_2 (same data, labeled mask, network, handler and args).
stratAl_model_1 = LeastConfidence(
    framework_2.X, framework_2.Y, framework_2.X_te, framework_2.Y_te,
    framework_2.idxs_lb, framework_2.net, framework_2.handler, framework_2.args,
    framework_2.n_pool, framework_2.device,
)
stratSSL_model_1 = pseudolabel(
    framework_2.X, framework_2.Y, framework_2.X_te, framework_2.Y_te,
    framework_2.idxs_lb, framework_2.net, framework_2.handler, framework_2.args,
    framework_2.n_pool, framework_2.device, framework_2.predict, framework_2.g,
)

# Report the strategies that were actually instantiated. The previous message
# echoed args.ALstrat / args.SSLstrat, which announced 'fixmatch' although
# `pseudolabel` is the SSL strategy constructed above.
# NOTE(review): args.SSLstrat still holds 'fixmatch'; any log line derived
# from args will keep showing that name — confirm which one is intended.
print(
    f' Strategy for active learning: {type(stratAl_model_1).__name__}'
    f' and strategy for semi-supervised learning used: {type(stratSSL_model_1).__name__}'
)

# Train father (kept as the last expression so the cell displays its result).
framework_2.train(alpha=2e-3, n_epoch=5)


 Strategy for active learningLeastConfidence and strategy for semi-supervised learning used fixmatch
Let's use 1 GPUs!
[Batch=000] [Loss=2.44]

==>>[2023-09-04 12:24:50] [Epoch=000/005] [framwork1(LeastConfidence+fixmatch) Need: 00:00:00] [LR=0.1000] [Best : Test Accuracy=0.00, Error=1.00]
[Batch=000] [Loss=2.43]

==>>[2023-09-04 12:24:52] [Epoch=001/005] [framwork1(LeastConfidence+fixmatch) Need: 00:00:13] [LR=0.1000] [Best : Test Accuracy=0.11, Error=0.89]
[Batch=000] [Loss=2.33]

==>>[2023-09-04 12:24:54] [Epoch=002/005] [framwork1(LeastConfidence+fixmatch) Need: 00:00:07] [LR=0.1000] [Best : Test Accuracy=0.11, Error=0.89]
[Batch=000] [Loss=2.17]

==>>[2023-09-04 12:24:56] [Epoch=003/005] [framwork1(LeastConfidence+fixmatch) Need: 00:00:04] [LR=0.1000] [Best : Test Accuracy=0.11, Error=0.89]
[Batch=000] [Loss=2.07]

==>>[2023-09-04 12:24:58] [Epoch=004/005] [framwork1(LeastConfidence+fixmatch) Need: 00:00:02] [LR=0.1000] [Best : Test Accuracy=0.12, Error=0.88]
---- save figure the 

0.1298

In [14]:
# Accuracy history: slot 0 holds the test accuracy measured before any query
# round, slots 1..NUM_ROUND are filled by the active-learning loop.
acc = np.zeros(NUM_ROUND + 1)
test_acc = framework_2.predict(framework_2.X_te, framework_2.Y_te)
acc[0] = test_acc

In [15]:
# Display the accuracy history; only slot 0 has been filled at this point.
acc

array([0.1298, 0.    , 0.    , 0.    , 0.    , 0.    ])

In [16]:
# Father/son agreement (Cohen's kappa) below which the round is trained with
# the active-learning strategy; otherwise the SSL strategy is used.
KAPPA_THRESHOLD = 0.8

# Total number of samples that may end up labeled (nEnd percent of the pool).
label_budget = int(nEnd * framework_2.n_pool / 100)

for rd in range(1, NUM_ROUND + 1):

    # Clamp this round's query size to the remaining budget. A local variable
    # is used so the NUM_QUERY setting is not overwritten and re-running the
    # cell starts from the configured value.
    labeled = int(framework_2.idxs_lb.sum())
    n_query = min(NUM_QUERY, label_budget - labeled)
    if n_query <= 0:
        print('Labeling budget exhausted; stopping before round {}/{}'.format(rd, NUM_ROUND), flush=True)
        break

    # query
    ts = time.time()
    q_idxs = stratAl_model_1.query(n_query)

    # predict father and son on the queried samples
    prediction = framework_2.predict_coefficient(framework_2.X, framework_2.Y, q_idxs)
    prediction_w = framework_2.predict_coefficient_w(framework_2.X, framework_2.Y, q_idxs)
    print("father: ", prediction)
    print("Son : ",prediction_w)

    # compute the agreement coefficient between the two prediction sets
    cof = am_capa(prediction, prediction_w)
    print(cof)

    # mark the queried samples as labeled and propagate the mask
    framework_2.idxs_lb[q_idxs] = True
    tp = time.time() - ts  # time spent on query + agreement for this round
    framework_2.update(framework_2.idxs_lb)

    # Pick the trainer for this round from the agreement score.
    if cof < KAPPA_THRESHOLD:
        round_label = 'AL Methods'
        # Fall back to the framework's own training when the AL strategy
        # object does not provide a train method.
        trainer = stratAl_model_1 if hasattr(stratAl_model_1, 'train') else framework_2
    else:
        round_label = "SSL Methods"
        trainer = stratSSL_model_1

    print(round_label)
    print('Round {}/{}'.format(rd, NUM_ROUND), flush=True)
    best_test_acc = trainer.train(alpha=2e-3, n_epoch=5)
    # Reported for both branches (previously only the AL branch printed it).
    print(best_test_acc)

    t_iter = time.time() - ts  # wall-clock time of the whole round

    # round accuracy
    acc[rd] = best_test_acc

father:  [4 4 4 2 5 4 2 4 4 4 4 4 4 9 4 4 9 2 2 4 4 9 4 4 4 2 4 4 7 4 4 4 4 4 4 2 4
 4 4 4 4 4 4 2 4 4 4 1 1 2 4 4 1 1 5 4 4 1 2 2 4 5 2 4 2 1 2 7 1 1 2 5 2 2
 1 5 7 2 2 5 1 1 5 1 2 5 1 5 1 2 5 2 1 8 1 1 5 5 2 1 2 7 5 2 1 5 2 5 5 5 5
 2 5 5 5 1 5 5 5 5 2 1 2 1 5 1 2 5 1 1 1 1 5 5 1 7 2 5 5 5 2 7 5 5 1 2 1 5
 1 2 1 1 5 1 5 2 5 5 1 5 1 1 7 2 7 7 1 1 7 1 2 1 1 1 1 1 2 1 2 1 2 1 2 1 1
 1 2 1 1 2 2 1 1 2 1 1 1 1 2 2 2 1 1 7 2 2 2 1 2 7 1 1 1 7 2 2 1 2 2 1 1 1
 2 1 2 1 7 2 2 1 1 2 2 1 1 1 1 1 7 7 1 1 1 1 1 1 7 1 2 1 2 2 1 2 2 1 1 1 1
 1 1 1 1 1 1 2 7 1 2 2 1 1 2 2 2 1 1 1 1 2 2 1 2 2 1 1 1 2 1 1 1 7 1 2 2 2
 7 1 2 1 2 1 1 1 1 1 2 1 1 2 1 1 7 2 1 2 1 1 1 1 7 1 1 1 1 2 2 1 1 2 1 1 1
 2 2 1 2 1 7 2 1 2 1 1 1 2 1 1 2 1 7 1 2 1 2 1 1 1 2 1 1 1 1 2 2 7 1 1 1 2
 2 2 1 7 1 1 1 7 1 2 2 2 2 1 2 2 2 1 2 2 2 2 1 1 1 1 1 2 1 2 2 1 7 1 1 1 2
 2 1 1 2 2 2 1 2 1 7 1 1 1 2 2 7 2 2 2 1 1 1 2 1 1 2 1 2 1 1 2 7 2 1 7 2 1
 2 2 1 1 2 1 7 1 2 1 1 1 1 2 1 2 2 1 1 7 2 1 1 0 1 2 1 2 2 2 1 1 1 2 1 1 1
 1 2 2 2 1 2 1 1

In [18]:
# Display the accuracy history after the query rounds have run.
acc

array([0.1298    , 0.64240003, 0.68559998, 0.82969999, 0.9357    ,
       0.95529997])