In [58]:
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision import datasets
from torchvision.models import resnet34,ResNet34_Weights,resnet18,vgg11
import torch.nn.functional as F

In [29]:
class DCGENERATOR(nn.Module):
    def __init__(self,latent_dim = 100,img_channels = 3,feature_maps = 64):
        super().__init__()
        # the input goes like (B,latent,1,1)
        self.model = nn.Sequential(
            nn.ConvTranspose2d(in_channels=latent_dim,out_channels=feature_maps * 4,kernel_size=4,padding=0,stride=1),
            nn.BatchNorm2d(feature_maps*4),
            nn.LeakyReLU(negative_slope=0.2,inplace=True),
            # outputs = (1-1)*1-2*0+4 =(256,4,4)

            nn.ConvTranspose2d(in_channels=256,out_channels=feature_maps*2,kernel_size=4,stride=2,padding=1),
            nn.BatchNorm2d(feature_maps*2),
            nn.LeakyReLU(negative_slope=0.2,inplace=True),
            #okay so now the shape becomes (4-1) * 2 - 2*1+4 = (8,8,128)

            nn.ConvTranspose2d(in_channels=128,out_channels=feature_maps,stride=2,kernel_size=4,padding=1),
            nn.BatchNorm2d(feature_maps),
            nn.LeakyReLU(negative_slope = 0.2,inplace=True),
            #okay so the shape becomes (8-1) * 2 + 4- 2 * 1 + 0 = (16,16,64)

            nn.ConvTranspose2d(in_channels=64,out_channels=3,kernel_size=4,stride=2,padding=1),
            nn.Tanh(),
        )
    def forward(self,z):
        z = z.view(z.size(0),z.size(1),1,1)
        return self.model(z)    

In [104]:
import detectors
import timm
import torchvision.models as models


def load_model():
    teacher_model = timm.create_model("resnet34_cifar100", pretrained=True)
    student_model = models.efficientnet_b0(num_classes = 10)
    
    generator_model = DCGENERATOR()

    return teacher_model,student_model,generator_model


In [105]:
teacher_model,student_model,generator_model = load_model()

In [81]:
from torchinfo import summary
summary(model = generator_model)

Layer (type:depth-idx)                   Param #
DCGENERATOR                              --
├─Sequential: 1-1                        --
│    └─ConvTranspose2d: 2-1              409,856
│    └─BatchNorm2d: 2-2                  512
│    └─LeakyReLU: 2-3                    --
│    └─ConvTranspose2d: 2-4              524,416
│    └─BatchNorm2d: 2-5                  256
│    └─LeakyReLU: 2-6                    --
│    └─ConvTranspose2d: 2-7              131,136
│    └─BatchNorm2d: 2-8                  128
│    └─LeakyReLU: 2-9                    --
│    └─ConvTranspose2d: 2-10             3,075
│    └─Tanh: 2-11                        --
Total params: 1,069,379
Trainable params: 1,069,379
Non-trainable params: 0

In [82]:
from torchvision.datasets import CIFAR100
from torch.utils.data import DataLoader
def load_cifar_test():
    transform_cifar100 = transforms.Compose(
       [
           transforms.ToTensor(),
           transforms.Normalize(
               mean=[0.5071, 0.4865, 0.4409],
               std=[0.2673, 0.2564, 0.2762]
           )
       ] 
    )
    test_cifar = CIFAR100(
        root="./data",train=False,download=True,transform=transform_cifar100
    )
    test_loader = DataLoader(
        dataset=test_cifar,batch_size=256,shuffle=True
    )
    
    return test_loader

In [83]:
from torchinfo import summary
summary(model = teacher_model)

Layer (type:depth-idx)                   Param #
ResNet                                   --
├─Conv2d: 1-1                            1,728
├─BatchNorm2d: 1-2                       128
├─ReLU: 1-3                              --
├─Identity: 1-4                          --
├─Sequential: 1-5                        --
│    └─BasicBlock: 2-1                   --
│    │    └─Conv2d: 3-1                  36,864
│    │    └─BatchNorm2d: 3-2             128
│    │    └─Identity: 3-3                --
│    │    └─ReLU: 3-4                    --
│    │    └─Identity: 3-5                --
│    │    └─Conv2d: 3-6                  36,864
│    │    └─BatchNorm2d: 3-7             128
│    │    └─ReLU: 3-8                    --
│    └─BasicBlock: 2-2                   --
│    │    └─Conv2d: 3-9                  36,864
│    │    └─BatchNorm2d: 3-10            128
│    │    └─Identity: 3-11               --
│    │    └─ReLU: 3-12                   --
│    │    └─Identity: 3-13               --
│    │  

In [84]:
from torchinfo import summary
summary(model = student_model)

Layer (type:depth-idx)                                  Param #
EfficientNet                                            --
├─Sequential: 1-1                                       --
│    └─Conv2dNormActivation: 2-1                        --
│    │    └─Conv2d: 3-1                                 864
│    │    └─BatchNorm2d: 3-2                            64
│    │    └─SiLU: 3-3                                   --
│    └─Sequential: 2-2                                  --
│    │    └─MBConv: 3-4                                 1,448
│    └─Sequential: 2-3                                  --
│    │    └─MBConv: 3-5                                 6,004
│    │    └─MBConv: 3-6                                 10,710
│    └─Sequential: 2-4                                  --
│    │    └─MBConv: 3-7                                 15,350
│    │    └─MBConv: 3-8                                 31,290
│    └─Sequential: 2-5                                  --
│    │    └─MBConv: 3-9         

In [106]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


teacher_total_params = count_parameters(teacher_model)
student_model_params = count_parameters(student_model)


In [107]:
((student_model_params)/teacher_total_params) * 100

18.849882587879048

In [108]:
teacher_total_params

21328292

In [90]:
student_model_params

4020358

In [109]:
#params
LATENT_DIM= 100
BATCH_SIZE = 129
K = 5
num_steps = 50_000


In [110]:
#optimzer 
generator_optimizer = torch.optim.Adam(generator_model.parameters(),lr = 1e-3,betas=(0.9, 0.999))
student_optimizer  = torch.optim.SGD(student_model.parameters(),lr = 0.1,momentum=0.9)


In [None]:
#training loop 
teacher_model.eval()
for step in range(num_steps):
    
    for _ in range(K):
        student_model.train()
        latent_vec = torch.randn(BATCH_SIZE,LATENT_DIM).cuda()
        x_fake = generator_model(latent_vec)
        with torch.no_grad():
            t_logits = teacher_model(x_fake)
        s_logits = student_model(x_fake)
        loss_im = torch.mean(torch.abs(t_logits - s_logits))
        student_optimizer.zero_grad()
        loss_im.backward()
        student_optimizer.step()

    # genration stage
    
    latent_vec = torch.randn(BATCH_SIZE,LATENT_DIM).cuda()
    x_fake = generator_model(latent_vec)
    t_logits = teacher_model(x_fake)
    s_logits = student_model(x_fake)
    loss_im = torch.mean(torch.abs(t_logits - s_logits))
    generator_loss = -torch.log(loss_im+1)
    generator_optimizer.zero_grad()
    generator_loss.backward()
    generator_optimizer.step()


    if step % 100 == 0:
        print(f"Step {step}: Imitation loss={loss_im.item():.4f}, Gen loss={generator_loss.item():.4f}")


In [None]:
test_loader = load_cifar_test()
def evaluate(model):
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.cuda(), labels.cuda()
            outputs = model(images)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total

acc_student1 = evaluate(student1)
print(f"Student1 Test Accuracy: {acc_student1:.4f}")