In [1]:
import torch
import torchvision
from torchvision.transforms import v2
from torchvision.models import efficientnet_b0,EfficientNet_B0_Weights,densenet121,DenseNet121_Weights
from torch.utils.data import DataLoader
import skorch
from skorch.helper import predefined_split
from skorch.callbacks import Checkpoint,Freezer
import numpy as np
from sklearn.metrics import roc_auc_score,f1_score



In [2]:
class KONet(torch.nn.Module):

    def __init__(
            self,
            m1_ratio=0.6,
            m2_ratio=0.4,
            m1_dropout=0.1,
            m2_dropout=0.3,
            n_classes=2
    ):
        super().__init__()
        assert m1_ratio+m2_ratio==1
        self.n_classes=n_classes
        self.m1_ratio=m1_ratio
        self.m2_ratio=m2_ratio
        self.m1_dropout=m1_dropout
        self.m2_dropout=m2_dropout

        self.efficient=efficientnet_b0(weights=EfficientNet_B0_Weights.DEFAULT)
        self.efficient.classifier[0]=torch.nn.Dropout(p=self.m1_dropout,inplace=True)
        self.efficient.classifier[-1]=torch.nn.Linear(in_features=1280,out_features=self.n_classes)

        self.dense=densenet121(weights=DenseNet121_Weights.DEFAULT)
        self.dense.classifier=torch.nn.Sequential(torch.nn.Dropout(p=self.m2_dropout,inplace=True),
                                            torch.nn.Linear(in_features=1024,out_features=n_classes),
                                            )

    def forward(self, x):
        m1=self.efficient(x)
        m2=self.dense(x)
        out=self.m1_ratio*m1+self.m2_ratio*m2
        return out

In [41]:
n_classes=2
image_shape=224
augmented_dataset_size=4000
path="D:\Osteoporosis detection\datasets\Osteoporosis Knee X-ray modified\Osteoporosis Knee X-ray"
non_augment_transform=v2.Compose([v2.ToImageTensor(),
                       v2.ToDtype(torch.float32),
                       v2.Resize((image_shape,image_shape),antialias=True),
                       v2.Normalize(mean=[0],std=[1]),
                       ])
transforms=v2.Compose([v2.ToImageTensor(),
                       v2.ToDtype(torch.float32),
                       v2.RandomAffine(degrees=30,shear=30),
                       v2.RandomZoomOut(side_range=(1,1.5)),
                       v2.Resize((image_shape,image_shape),antialias=True),
                       v2.Normalize(mean=[0],std=[1]),
                       ])

In [42]:
non_augmented_dataset=torchvision.datasets.ImageFolder(path,transform=non_augment_transform)
dataset=torchvision.datasets.ImageFolder(path,transform=transforms)
factor=augmented_dataset_size//len(dataset)-1

new_dataset=torch.utils.data.ConcatDataset([non_augmented_dataset]+[dataset for _ in range(factor)])
del non_augmented_dataset,dataset

In [43]:
#dataset=torchvision.datasets.ImageFolder(path,transform=transforms)
generator1 = torch.Generator().manual_seed(42)
train_split=0.2
valid_split=0.1
test_split=0.7
train_set,valid_set,test_set=torch.utils.data.random_split(new_dataset, [train_split,valid_split,test_split],
                                                            generator=generator1)

In [6]:
class distiller(torch.nn.Module):
    def __init__(
            self,
            large_model,
            small_model
    ):
        super().__init__()
        self.large_model=large_model
        self.small_model=small_model

    def forward(self, x):
        large_output=self.large_model(x)
        small_output=self.small_model(x)
        print(large_output)
        print(small_output)
        return (small_output,large_output)

In [7]:
model_name='distiller'
large_model_name='dense'
small_model_name='conv_next'

In [8]:
#Large model initiallization
if large_model_name=='dense':
    large_model=densenet121(weights=DenseNet121_Weights.DEFAULT)
    p=0.3
    large_model.classifier=torch.nn.Sequential(torch.nn.Dropout(p=p,inplace=True),
                                        torch.nn.Linear(in_features=1024,out_features=n_classes),
                                        )
elif large_model_name=='KONetOtherFinetuned':
    m1_ratio=0.6
    m2_ratio=0.4
    m1_dropout=0.1
    m2_dropout=0.3
    large_model=KONet(m1_ratio=m1_ratio,m2_ratio=m2_ratio,m1_dropout=m1_dropout,m2_dropout=m2_dropout,n_classes=n_classes)
#model.load_state_dict(torch.load(f'model/{model_name}best_param.pkl'))
if small_model_name=='conv_next':
    p=0.3
    small_model=torchvision.models.convnext_tiny(weights='DEFAULT')
    small_model.classifier[2]=torch.nn.Sequential(torch.nn.Dropout(p=p,inplace=True),
                                        torch.nn.Linear(in_features=768,out_features=n_classes),
                                        )
model=distiller(large_model=large_model,small_model=small_model)
#Freeze entirety of large model so only small model changes
freeze=['large_model.*.weight']

In [9]:
from typing import Optional
from torch import Tensor
from torch.nn.modules.loss import _WeightedLoss
from torch.nn import functional as F
#Now we need to create our own loss function which will perform cross entropy loss
class distill_loss(_WeightedLoss):
    __constants__ = ['ignore_index', 'reduction', 'label_smoothing']
    ignore_index: int
    label_smoothing: float

    def __init__(self, weight: Optional[Tensor] = None, size_average=None, ignore_index: int = -100,
                 reduce=None, reduction: str = 'mean', label_smoothing: float = 0.0, T:int = 2,
                 soft_target_loss_weight: float = 0.25, ce_loss_weight: float = 0.75,) -> None:
        super().__init__(weight, size_average, reduce, reduction)
        self.ignore_index = ignore_index
        self.label_smoothing = label_smoothing
        self.T=T
        self.soft_target_loss_weight=soft_target_loss_weight
        self.ce_loss_weight=ce_loss_weight

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        soft_targets = F.softmax(input[1] / self.T, dim=-1)
        soft_prob = F.softmax(input[0] / self.T, dim=-1)
        #print(soft_targets)
        #print(soft_prob)
        soft_targets_loss = -torch.sum(soft_targets * (soft_prob.log())) / soft_prob.size()[0] * (self.T**2)
        #print(soft_targets_loss)
        label_loss = F.cross_entropy(input[0], target, weight=self.weight,
                               ignore_index=self.ignore_index, reduction=self.reduction,
                               label_smoothing=self.label_smoothing)  
        #print(label_loss) 
        loss = self.soft_target_loss_weight * soft_targets_loss + self.ce_loss_weight * label_loss
        return loss

In [45]:
monitor = lambda net: any(net.history[-1, ('valid_accuracy_best','valid_loss_best')])
cp=Checkpoint(monitor='valid_loss_best',dirname='model',f_params=f'{model_name}best_param.pkl',
               f_optimizer=f'{model_name}best_opt.pkl', f_history=f'{model_name}best_history.json')
cb = skorch.callbacks.Freezer(freeze)
classifier = skorch.NeuralNetClassifier(
        model,
        criterion=distill_loss(),
        optimizer=torch.optim.AdamW,
        train_split=predefined_split(valid_set),
        iterator_train=DataLoader,
        iterator_valid=DataLoader,
        iterator_train__shuffle=True,
        iterator_train__pin_memory=True,
        iterator_valid__pin_memory=True,
        #iterator_train__num_workers=1,
        #iterator_valid__num_workers=1,
        #iterator_train__persistent_workers=True,
        #iterator_valid__persistent_workers=True,
        batch_size=32,
        device='cuda',
        callbacks=[cp,cb],#Try to implement accuracy and f1 score callables here
        warm_start=True,
        )
classifier.initialize()
classifier.module_.large_model.load_state_dict(torch.load(f'model/{large_model_name}best_param.pkl'))

<All keys matched successfully>

In [11]:
test=np.ones((1,3,image_shape,image_shape),dtype=np.float32)
out=classifier.predict_proba(test)

tensor([[ 2.1463, -2.1091]], device='cuda:0')
tensor([[0.0024, 0.2747]], device='cuda:0')


In [None]:
classifier.fit(train_set,y=None,epochs=2)

In [46]:
classifier.load_params(f_params='model/distiller_otherbest_param.pkl')
distilled_model=classifier.module_.small_model
classifier.module_=distilled_model

In [49]:
small_model_name

'conv_next'

In [50]:
classifier.save_params(f_params=f'model/{small_model_name}_distilled_otherbest_param.pkl')

In [48]:
iterations=5
accuracy=[]
f1=[]
auc=[]
test_loader=DataLoader(test_set,batch_size=8,shuffle=False,num_workers=4,pin_memory=True,persistent_workers=True)
for i in range(iterations):
    print(i)
    probs=[]
    actual_labels=[]
    for test_features, actual_lb in iter(test_loader):
        prob=classifier.predict_proba(test_features)
        actual_lb=np.array(actual_lb)
        probs.append(prob)
        actual_labels.append(actual_lb)

    probs=np.concatenate(probs)
    pred_labels=np.argmax(probs,axis=1)
    actual_labels=np.concatenate(actual_labels)

    iteration_auc=roc_auc_score(actual_labels,probs[:,1])
    iteration_accuracy=np.mean(pred_labels==actual_labels)
    iteration_f1=f1_score(actual_labels,pred_labels)

    accuracy.append(iteration_accuracy)
    f1.append(iteration_f1)
    auc.append(iteration_auc)

print(small_model_name)

print(f"Accuracy mean: {np.mean(accuracy)} standard deviation: {np.std(accuracy)}")
print(f"F1-Score mean: {np.mean(f1)} standard deviation: {np.std(f1)}")
print(f"ROC_AUC  mean: {np.mean(auc)} standard deviation: {np.std(auc)}")

0
1
2
3
4
conv_next
Accuracy mean: 0.9987734487734488 standard deviation: 0.0004893455976280795
F1-Score mean: 0.998925789594517 standard deviation: 0.0004294201753277975
ROC_AUC  mean: 0.9999944733530592 standard deviation: 8.16780007857457e-06
