In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image
from torch.optim.lr_scheduler import StepLR
import argparse
import copy

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dataset
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),  
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
trainset = torchvision.datasets.CIFAR10(root='./CIFAR10', train=True, download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(root='./CIFAR10', train=False, download=True, transform=transform_test)
test_loader_this  = torch.utils.data.DataLoader(testset, batch_size=1000, shuffle=False, num_workers=2)


def seed_torch(seed=0):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

Files already downloaded and verified
Files already downloaded and verified


In [4]:
class gamma_layer(nn.Module):

    def __init__(self, input_channel, output_channel):
        super(gamma_layer, self).__init__()
        self.H = nn.Parameter(torch.ones(output_channel, input_channel))
        self.b = nn.Parameter(torch.ones(output_channel))
        self.H.data.normal_(0, 0.1)
        self.b.data.normal_(0, 0.001)

    def forward(self, x):
        H = torch.abs(self.H)
        x = F.linear(x,H)
        return torch.tanh(x)

In [None]:
intermediate_dim = 64
threshold = 0.0001
beta = 0.001
test_batch_size=1000
channel_noise_arg = 0.5
batch_size = 128
epochs = 10
lr = 0.001
gamma = 0.5
weights = 'MNIST_model_dim:64_beta:0.001_accuracy:85.7180_model.pth'
decay_step = 60

: 

In [6]:
class gamma_function(nn.Module):

    def __init__(self):
        super(gamma_function, self).__init__()
        self.f1 = gamma_layer(1,16)
        self.f2 = gamma_layer(16,16)
        self.f3 = gamma_layer(16,16)
        self.f4 = gamma_layer(16,intermediate_dim)
        
    def forward(self, x):
        x = self.f1(x)
        x = self.f2(x)
        x = self.f3(x)
        x = self.f4(x)
        return x

In [7]:
def get_reduced_labels(target):
    # Define the mapping from original labels to reduced labels
    reduced_labels = target.clone()
    label_mapping = {
        0: 0,  # Airplane
        1: 1,  # Automobile
        2: 2,  # Bird
        3: 3,  # Cat
        4: 4,  # Deer
        5: 3,  # Dog
        6: 5,  # Frog
        7: 4,  # Horse
        8: 0,  # Ship
        9: 1   # Truck
    }
    
    for original, reduced in label_mapping.items():
        reduced_labels[target == original] = reduced
        
    return reduced_labels
def create_mask_matrix():
    mask_matrix = torch.zeros((6, 10))  # 6 reduced labels, 10 final labels
    # Define the valid final labels for each reduced label
    mask_matrix[0, [0, 8]] = 1  # Reduced label 0 maps to final labels 0 (Airplane) and 8 (Ship)
    mask_matrix[1, [1, 9]] = 1  # Reduced label 1 maps to final labels 1 (Automobile) and 9 (Truck)
    mask_matrix[2, [2]] = 1     # Reduced label 2 maps to final label 2 (Bird)
    mask_matrix[3, [3, 5]] = 1  # Reduced label 3 maps to final labels 3 (Cat) and 5 (Dog)
    mask_matrix[4, [4, 7]] = 1  # Reduced label 4 maps to final labels 4 (Deer) and 7 (Horse)
    mask_matrix[5, [6]] = 1     # Reduced label 5 maps to final label 6 (Frog)
    return mask_matrix

mask_matrix = create_mask_matrix().to(device)


In [8]:
class Flatten(nn.Module):
    def forward(self, x): return x.view(x.size(0), x.size(1))

class Mul(nn.Module):
    def __init__(self, weight):
        super().__init__()
        self.weight = weight
    def __call__(self, x): 
        return x*self.weight

class Net(nn.Module):
    def __init__(self):
        super().__init__()

        self.hidden_channel = intermediate_dim
        self.mask_matrix = create_mask_matrix().to(device)

        self.prep = nn.Sequential(
                    nn.Conv2d(3,64,kernel_size = 3,stride = 1, padding = 1, bias = False),
                    nn.BatchNorm2d(64),
                    nn.ReLU()
                    )
        self.layer1 = nn.Sequential(
                    nn.Conv2d(64,128,kernel_size = 3,stride = 1, padding = 1, bias = False),
                    nn.BatchNorm2d(128),
                    nn.ReLU(),
                    nn.MaxPool2d(kernel_size = 2, stride = 2, padding = 0, dilation = 1, ceil_mode = False)
                    )
        self.layer1_res = nn.Sequential(
                    nn.Conv2d(128,128,kernel_size = 3,stride = 1, padding = 1, bias = False),
                    nn.BatchNorm2d(128),
                    nn.ReLU(),
                    nn.Conv2d(128,128,kernel_size = 3,stride = 1, padding = 1, bias = False),
                    nn.BatchNorm2d(128),
                    nn.ReLU()
                    )
        self.layer2 = nn.Sequential(
                    nn.Conv2d(128,256,kernel_size = 3,stride = 1, padding = 1, bias = False),
                    nn.BatchNorm2d(256),
                    nn.ReLU(),
                    nn.MaxPool2d(kernel_size = 2, stride = 2)
                    )
        self.layer3 = nn.Sequential(
                    nn.Conv2d(256,512,kernel_size = 3,stride = 1, padding = 1, bias = False),
                    nn.BatchNorm2d(512),
                    nn.ReLU(),
                    nn.MaxPool2d(kernel_size = 2, stride = 2, padding = 0, dilation = 1, ceil_mode = False)
                    )
        self.layer3_res = nn.Sequential(
                    nn.Conv2d(512,512,kernel_size = 3,stride = 1, padding = 1, bias = False),
                    nn.BatchNorm2d(512),
                    nn.ReLU(),
                    nn.Conv2d(512,512,kernel_size = 3,stride = 1, padding = 1, bias = False),
                    nn.BatchNorm2d(512),
                    nn.ReLU()
                    )

        self.classifier1 = nn.Sequential(
                    nn.MaxPool2d(kernel_size = 4, stride = 4, padding = 0, dilation = 1, ceil_mode = False),
                    Flatten()
                    )
        self.classifier_reduced = nn.Linear(512, 6)  # Reduced classifier
        self.classifier_final = nn.Linear(512, 10)  # Final classifier
        # self.classifier2 = nn.Sequential(
        #             nn.Linear(512,10,bias = False),
        #             Mul(0.125)
        #             )
        
        self.encoder1 = nn.Sequential(
                        nn.Conv2d(512,4,kernel_size = 3,stride = 1, padding = 1, bias = False),
                        nn.BatchNorm2d(4),
                        nn.ReLU()
                        )
        self.encoder2 = nn.Sequential(
                        nn.Linear(64,64),
                        nn.Sigmoid()
                        )

        self.encoder3_weight = nn.Parameter(torch.Tensor(self.hidden_channel, 64))
        self.encoder3_bias = nn.Parameter(torch.Tensor(self.hidden_channel))
        self.encoder3_weight.data.normal_(0, 0.5)
        self.encoder3_bias.data.normal_(0, 0.1)

        self.decoder1 = nn.Linear(self.hidden_channel,64)
        self.decoder1_2 = nn.Sequential(
                        nn.Linear(64,64),
                        nn.ReLU()
                        )
        self.decoder1_2_2 = nn.Sequential(
                        nn.Linear(1,16),
                        nn.ReLU(),
                        nn.Linear(16,16),
                        nn.ReLU(),
                        nn.Linear(16,16),
                        nn.ReLU()
                        )
        self.decoder1_3 = nn.Sequential(
                        nn.Linear(64+16,64),
                        nn.ReLU()
                        )
        self.decoder2 = nn.Sequential(
                        nn.Conv2d(4,512,kernel_size = 3,stride = 1, padding = 1, bias = False),
                        nn.BatchNorm2d(512),
                        nn.ReLU()
                        )

        self.Tanh = nn.Tanh()
        self.gamma_mu = gamma_function().to(device)
        self.upper_tri_matrix = torch.triu(torch.ones((intermediate_dim,intermediate_dim))).to(device)

    def forward(self, x, epoch, noise = 0.1):

        x = self.prep(x)
        x = self.layer1(x)
        res = self.layer1_res(x)
        x = res + x
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.encoder1(x)
        x = torch.reshape(x,(x.size()[0],4*4*4))

        # Dynamic Channel Conditions
        if self.training:
            channel_noise = torch.rand(1)*0.27 + 0.05
        else:
            channel_noise = torch.FloatTensor([1]) * noise
        channel_noise = channel_noise.to(device)

        x = self.encoder2(x)
        x_norm2 = torch.norm(x,dim=1)
        x = 64 * (x.permute(1,0)/(x_norm2+1e-6)).permute(1,0)
        weight3 = F.tanh(self.encoder3_weight)
        bias3 = F.tanh(self.encoder3_bias)
        weight3 = torch.clamp(torch.abs(weight3),min = 1e-3) * torch.sign(weight3.detach())
        bias3 = torch.clamp(torch.abs(bias3),min = 1e-3) * torch.sign(bias3.detach())
        l2_norm_squared = torch.sum(weight3.pow(2),dim = 1) + bias3.pow(2)
        l2_norm = l2_norm_squared.pow(0.5)
        weight3 = (weight3.permute(1,0) / (l2_norm+1e-6)).permute(1,0)
        bias3 = bias3 / (l2_norm+1e-6)
        x = F.linear(x, weight3, bias3)

        mu = self.gamma_mu(channel_noise)
        mu = F.linear(mu, self.upper_tri_matrix)
        mu = torch.clamp(mu,min = 1e-4)
        encoded_feature = torch.tanh(x * mu)
        encoded_feature = torch.clamp(torch.abs(encoded_feature),min = 1e-2) * torch.sign(encoded_feature.detach())
        
        # KL divergence
        KL = self.KL_log_uniform(channel_noise,torch.abs(encoded_feature))

        # Gaussian channel noise
        x = encoded_feature + torch.randn_like(encoded_feature) * channel_noise

        if self.training:
            if epoch > 60:
                x = x * self.get_mask(mu,threshold = threshold)
        else:
            x = x * self.get_mask(mu,threshold = threshold)

        x = F.relu(self.decoder1(x))
        x = self.decoder1_2(x)
        noise_feature = self.decoder1_2_2(channel_noise)
        noise_feature = noise_feature.expand(x.size()[0],16)
        x = torch.cat((x,noise_feature),dim=1)
        x = self.decoder1_3(x)
        x = torch.reshape(x,(-1,4,4,4))
        decoded_feature = self.decoder2(x)
        x = self.layer3_res(decoded_feature)
        x = x + decoded_feature
        x = self.classifier1(x)
        output_reduced = self.classifier_reduced(x)
        output_final_logits = self.classifier_final(x)
        
        return F.log_softmax(output_reduced, dim=1), output_final_logits, KL * 0.1 / channel_noise

    def KL_log_uniform(self,channel_noise,encoded_feature):

        alpha = (channel_noise/encoded_feature)
        k1 = 0.63576
        k2 = 1.8732
        k3 = 1.48695
        batch_size = alpha.size(0)
        KL_term = k1 * F.sigmoid(k2 + k3 * 2 * torch.log(alpha)) - 0.5 * F.softplus(-2 * torch.log(alpha)) - k1
        return - torch.sum(KL_term) / batch_size

    def get_mask(self, mu, threshold=threshold):
        alpha = mu.detach()
        hard_mask = (alpha > threshold).float()
        return hard_mask

    def get_mask_inference(self, channel_noise, threshold = threshold):
        mu = self.gamma_mu(channel_noise)
        alpha = F.linear(mu, self.upper_tri_matrix)
        hard_mask = (alpha > threshold).float()
        return hard_mask, alpha

In [9]:
model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr,  weight_decay=1e-4)
scheduler = StepLR(optimizer, step_size=decay_step, gamma=gamma)

In [10]:
def train(model=model):
    
    test_acc = 0
    pruned_dim = 0
    saved_model = {}

    for epoch in range(epochs):

        print('\nepoch:{}'.format(epoch))
        if (epoch) % 10 == 0:
            data_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True,
                                                      num_workers=4, pin_memory=True)

        for i, (x, y) in enumerate(data_loader):
            x = x.to(device)
            y = y.to(device)
            model.train()
            output_reduced, output_final_logits, KL = model(x, epoch)
            
            # Get reduced labels
            target_reduced = get_reduced_labels(y)
            
            # Compute the final output based on the reduced label mask
            #batch_size = output_final_logits.size(0)
            mask = model.mask_matrix[target_reduced]
            masked_logits = output_final_logits * mask
            
            # Normalize masked logits to form a proper log_softmax
            output_final = F.log_softmax(masked_logits, dim=1)
            
            criterion = nn.CrossEntropyLoss()
            criterion = criterion.to(device)
            
            loss_reduced = criterion(output_reduced, target_reduced)
            loss_final = criterion(output_final, y)
            
            if epoch <= 20:
                loss = loss_reduced + loss_final
            else:
                anneal_ratio = min(1, (epoch - 20) / 20)
                loss = loss_reduced + loss_final + beta * KL * anneal_ratio

            if torch.isnan(loss):
                raise Exception("NaN value")

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        scheduler.step()

        acc, pruned_number = test(epoch, noise=0.1)

        print('Test Accuracy:', acc, 'Pruned dim:', pruned_number, 'Activated dim:', intermediate_dim - pruned_number)

        if epoch > 7:
            if (acc > test_acc and pruned_number == pruned_dim) or pruned_number > pruned_dim:
                test_acc = acc
                pruned_dim = pruned_number
                saved_model = copy.deepcopy(model.state_dict())
                print('Best ckpt:', test_acc, 'pruned_number:', pruned_dim, 'beta:', beta)
                torch.save({'model': saved_model}, './CIFAR_model.pth')
    print('Best Accuracy:', test_acc, 'Intermediate Dim:', intermediate_dim, 'Beta:', beta)
    torch.save({'model': saved_model}, './CIFAR_model_model.pth')


In [11]:
def test(epoch, noise=0.1):
    with torch.no_grad():
        model.eval()
        correct = 0
        total = 0
        for i, (images, labels) in enumerate(test_loader_this): 
            images = images.to(device)
            labels = labels.to(device)
            output_reduced, output_final_logits, _ = model(images, epoch, noise)
            
            # Get reduced labels
            target_reduced = get_reduced_labels(labels)
            
            # Compute the final output based on the reduced label mask
            batch_size = output_final_logits.size(0)
            mask = model.mask_matrix[target_reduced]
            masked_logits = output_final_logits * mask
            
            # Normalize masked logits to form a proper log_softmax
            output_final = F.log_softmax(masked_logits, dim=1)
            
            # Evaluate based on the final output
            _, predicted = torch.max(output_final.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        
        hard_mask, mu = model.get_mask_inference(torch.FloatTensor([noise]).to(device))
        index = torch.nonzero(torch.lt(hard_mask, 0.5)).squeeze(1)
        pruned_number = index.size()[0]
        return 100 * correct / total, pruned_number


In [12]:
train()


epoch:0
Test Accuracy: 72.18 Pruned dim: 0 Activated dim: 64

epoch:1
Test Accuracy: 78.89 Pruned dim: 0 Activated dim: 64

epoch:2
Test Accuracy: 81.06 Pruned dim: 0 Activated dim: 64

epoch:3
Test Accuracy: 84.86 Pruned dim: 0 Activated dim: 64

epoch:4
Test Accuracy: 87.53 Pruned dim: 0 Activated dim: 64

epoch:5
Test Accuracy: 88.56 Pruned dim: 0 Activated dim: 64

epoch:6
Test Accuracy: 90.36 Pruned dim: 0 Activated dim: 64

epoch:7
Test Accuracy: 90.23 Pruned dim: 0 Activated dim: 64

epoch:8
Test Accuracy: 90.97 Pruned dim: 0 Activated dim: 64
Best ckpt: 90.97 pruned_number: 0 beta: 0.001

epoch:9


: 

In [None]:
channel_noise_arg = 0.5
weights = 'VL-VFE/CIFAR-model.pth'
model = Net().to(device)
model.load_state_dict(torch.load(weights)['model'])
accuracy = 0
t = 20
for i in range (t):
    acc, pruned_number = test(0,channel_noise_arg)
    accuracy += acc
#print('Noise level:',args.channel_noise,'Test Accuracy:',accuracy/t,'Activated dim:', args.intermediate_dim - pruned_number)
print('Noise level:',channel_noise_arg, 'Test Accuracy:', accuracy/t, 'Pruned dim:', pruned_number, 'Activated dim:', intermediate_dim - pruned_number)

: 