In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torch.nn.functional as F
import torchvision.transforms as transforms

# Per-channel normalization statistics shared by the train and test pipelines.
# NOTE(review): these are the ImageNet mean/std, not CIFAR-100's own
# statistics; kept unchanged to preserve the existing training behavior.
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: random 32x32 crop from a 4-pixel padded image and a random
# horizontal flip (augmentation), then tensor conversion and normalization.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
])

# Evaluation pipeline: no augmentation, only tensor conversion and the same normalization.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
])

# Load CIFAR-100 into ./data (downloaded on first run). Building datasets and
# loaders involves no autograd, so no torch.no_grad() context is needed here.
trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=0)

testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=0)


In [None]:
class ResBlock(nn.Module):
    """Basic residual block: conv3x3 -> BN -> ReLU -> conv3x3 -> BN, add shortcut, ReLU.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the block.
        stride: stride of the first 3x3 convolution (and of the projection
            shortcut, when one is used); ``stride > 1`` downsamples spatially.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super(ResBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        # The shortcut must produce the same shape as the residual branch. That
        # requires a 1x1 projection whenever the channel count changes OR the
        # spatial size changes (stride != 1); checking channels alone would make
        # e.g. ResBlock(c, c, stride=2) fail at the addition below.
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        """Apply the block to ``x``; output has ``out_channels`` channels."""
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += self.shortcut(residual)
        out = self.relu(out)

        return out

class ResNet(nn.Module):
    """ResNet-18-style classifier for small RGB images.

    A 3x3 stem convolution (no max-pool) is followed by four stages of two
    ``ResBlock``s each (64, 128, 256, 512 channels; stages 2-4 start with
    stride 2), global average pooling and a linear head.

    Args:
        num_classes: length of the output logit vector (default 100).
    """

    def __init__(self, num_classes=100):
        super(ResNet, self).__init__()
        # Stem: 3 -> 64 channels at full resolution.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        # Residual stages, created in the same order as the blocks are listed
        # so parameter initialization and state_dict keys stay the same.
        self.layer1 = self._stage(64, 64, first_stride=1)
        self.layer2 = self._stage(64, 128, first_stride=2)
        self.layer3 = self._stage(128, 256, first_stride=2)
        self.layer4 = self._stage(256, 512, first_stride=2)
        # Classification head.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    @staticmethod
    def _stage(in_channels, out_channels, first_stride):
        """Two ResBlocks; only the first may change the channel count or stride."""
        return nn.Sequential(
            ResBlock(in_channels, out_channels, stride=first_stride),
            ResBlock(out_channels, out_channels, stride=1),
        )

    def forward(self, x):
        """Map an image batch to logits of shape (N, num_classes)."""
        features = self.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)
        pooled = self.avgpool(features).flatten(1)
        return self.fc(pooled)


net = ResNet()

In [None]:
# Softmax cross-entropy on the ResNet logits, optimized with SGD + momentum.
# NOTE(review): the optimizer is built before `net.to(device)` in the next cell;
# the torch.optim docs recommend moving the model first — confirm this is intended.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=net.parameters(), lr=0.01, momentum=0.9)


In [None]:
# Use the first CUDA GPU when one is available, otherwise the CPU, and move the
# ResNet's parameters and buffers there. The bare `device` shows the choice.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
net.to(device)
device

In [None]:
# Train the ResNet for 10 epochs, printing the mean loss of every 100 mini-batches.
# Explicit train mode: BatchNorm must use batch statistics while training, which
# matters if an evaluation cell (which switches to eval mode) ran before this one.
net.train()
for epoch in range(10):
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(trainloader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 100 == 99:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 100))
            running_loss = 0.0

print('Finished Training')


In [None]:
# Top-1 accuracy of the ResNet on the test split.
# Eval mode makes BatchNorm use its running statistics; in train mode the test
# batches would both skew the predictions and overwrite those statistics.
net.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the %d test images: %d %%' % (
    total, 100 * correct / total))


In [None]:
class ResNetGumbel(nn.Module):
    """ResNet-18-style classifier assembled from ``ResNetBlock`` stages.

    Despite its name, this network performs no Gumbel sampling: ``forward``
    returns the raw output of the final ``nn.Linear`` layer.

    Args:
        num_classes: number of output classes.
    """

    def __init__(self, num_classes):
        super(ResNetGumbel, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        # Channel width of the feature map entering the next stage; _make_layer
        # reads and advances it so each stage is wired to its predecessor.
        self.in_planes = 64
        self.layer1 = self._make_layer(64, 2)
        self.layer2 = self._make_layer(128, 2, stride=2)
        self.layer3 = self._make_layer(256, 2, stride=2)
        self.layer4 = self._make_layer(512, 2, stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, planes, blocks, stride=1):
        """Build a stage of ``blocks`` ResNetBlocks that outputs ``planes`` channels.

        Args:
            planes: output channels of every block in the stage.
            blocks: number of blocks (the first one is always created).
            stride: stride of the first block; the remaining blocks use stride 1.

        Returns:
            ``nn.Sequential`` of the blocks. Side effect: sets
            ``self.in_planes`` to ``planes`` for the next stage.
        """
        downsample = None
        # The first block's shortcut needs a 1x1 projection whenever it changes
        # the spatial size or the channel count. Its input width is the previous
        # stage's output (self.in_planes); a fixed 64 would break layer3/layer4,
        # which receive 128- and 256-channel tensors.
        if stride != 1 or self.in_planes != planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes),
            )

        layers = [ResNetBlock(self.in_planes, planes, stride, downsample)]
        self.in_planes = planes
        layers.extend(ResNetBlock(planes, planes) for _ in range(1, blocks))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map an image batch to class scores of shape (N, num_classes)."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

class ResNetBlock(nn.Module):
    """Basic residual block: conv3x3 -> BN -> ReLU -> conv3x3 -> BN, add shortcut, ReLU.

    Args:
        in_planes: channels of the incoming feature map.
        planes: channels produced by both 3x3 convolutions.
        stride: stride of the first convolution.
        downsample: optional module applied to the input to build the shortcut
            when the main branch changes shape; ``None`` means identity.
    """

    def __init__(self, in_planes, planes, stride=1, downsample=None):
        super(ResNetBlock, self).__init__()
        # Registration order (conv1, bn1, relu, conv2, bn2, downsample) is kept
        # so parameter initialization and state_dict keys are unchanged.
        self.conv1 = nn.Conv2d(in_planes, planes, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, 3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride  # stored for introspection; forward() does not read it

    def forward(self, x):
        """Return relu(main_branch(x) + shortcut(x))."""
        shortcut = x if self.downsample is None else self.downsample(x)
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        return self.relu(branch + shortcut)

class GumbelSoftmax(nn.Module):
    """Gumbel-Softmax relaxation: softmax((logits + g) / temperature), g ~ Gumbel(0, 1).

    NOTE(review): a second ``GumbelSoftmax`` class with a different constructor
    is defined in a later cell and shadows this one once that cell runs.

    Args:
        temperature: softmax temperature; smaller values push samples toward one-hot.
    """

    def __init__(self, temperature=1.0):
        super(GumbelSoftmax, self).__init__()
        self.temperature = temperature

    def forward(self, logits):
        """Return a relaxed one-hot sample over the last dimension of ``logits``."""
        # torch.rand_like draws from [0, 1); an exact 0 would give log(0) = -inf,
        # -inf Gumbel noise and NaN probabilities. Clamping to the dtype's
        # smallest positive normal value keeps the noise finite.
        uniform = torch.rand_like(logits).clamp_min(torch.finfo(logits.dtype).tiny)
        gumbels = -torch.log(-torch.log(uniform))
        noisy_logits = (logits + gumbels) / self.temperature
        sampled_logits = F.softmax(noisy_logits, dim=-1)
        return sampled_logits


In [None]:
class GumbelSoftmax(nn.Module):
    """Three-layer MLP classifier whose output is a Gumbel-Softmax sample.

    Inputs are flattened to ``(-1, input_dim)`` and passed through
    Linear -> ReLU -> Dropout -> Linear -> ReLU -> Dropout -> Linear. The
    resulting logits get Gumbel(0, 1) noise and a softmax at temperature
    ``temp``, so ``forward`` returns row-wise probabilities (not logits), and
    the output is stochastic in eval mode as well.

    Args:
        input_dim: number of features per example after flattening.
        hidden_dim: width of the two hidden layers.
        n_classes: number of output classes.
        temp: Gumbel-Softmax temperature.
        dropout_prob: dropout probability after each hidden layer.
    """

    def __init__(self, input_dim, hidden_dim, n_classes, temp=1.0, dropout_prob=0.5):
        super(GumbelSoftmax, self).__init__()
        # Stored so forward() does not depend on a notebook-global `input_dim`.
        self.input_dim = input_dim
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, n_classes)
        self.temp = temp
        self.dropout = nn.Dropout(p=dropout_prob)

    def forward(self, x):
        """Return Gumbel-Softmax class probabilities of shape (N, n_classes)."""
        x = x.reshape(-1, self.input_dim)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.dropout(x)
        logits = self.fc3(x)
        y = self.gumbel_softmax(logits, self.temp)
        return y

    def gumbel_softmax(self, logits, temperature):
        """Add Gumbel(0, 1) noise to ``logits`` and softmax over the last dim at ``temperature``.

        Noise is drawn with the same shape, dtype and device as ``logits`` by
        inverse-CDF sampling, -log(-log(U)) with U ~ Uniform(0, 1). No global
        ``device`` or separately imported ``Gumbel`` distribution is needed.
        """
        # rand_like can return exactly 0; clamp so both logs stay finite.
        uniform = torch.rand_like(logits).clamp_min(torch.finfo(logits.dtype).tiny)
        noise = -torch.log(-torch.log(uniform))
        y = F.softmax((logits + noise) / temperature, dim=-1)
        return y


In [None]:
# Hyperparameters for the Gumbel-Softmax MLP experiment.
# NOTE(review): the model/optimizer cells below still hard-code several of these
# values; they are set here to the values actually in effect so this cell does
# not contradict the run and can later be wired in without changing results.
input_dim = 3 * 32 * 32   # flattened image: 3 channels x 32 x 32 pixels
hidden_dim = 512          # width of both hidden layers (read by the model cells)
n_classes = 100           # CIFAR-100 classes (read by the model cells)
n_layers = 3              # presumably the MLP's Linear-layer count; not read by any cell here
lr = 0.1                  # SGD learning rate used by the training cell
batch_size = 128          # batch size of trainloader/testloader
num_epochs = 10           # read by the training loop
temperature = 0.5         # Gumbel-Softmax temperature the model cells pass as temp=0.5


In [None]:
# Instantiate the Gumbel-Softmax MLP on the selected device.
# NOTE(review): the training cell re-creates `model`, replacing this instance.
model = GumbelSoftmax(
    input_dim=3072,
    hidden_dim=hidden_dim,
    n_classes=n_classes,
    temp=0.5,
).to(device)


In [None]:
# Move the model and the most recently assigned `images`/`labels` batch (left
# over from an earlier cell — confirm which one) to the selected device.
# Uses `device` (CUDA when available, else CPU) instead of a hard-coded 'cuda',
# which raises on machines without a GPU.
model = model.to(device)
images = images.to(device)
labels = labels.to(device)


In [None]:
# Define the optimizer and loss function
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GumbelSoftmax(input_dim=3072, hidden_dim=hidden_dim, n_classes=n_classes, temp=0.5).to(device)
optimizer = optim.SGD(model.parameters(), lr=0.1)
# The model returns softmax probabilities, not logits. nn.CrossEntropyLoss would
# apply a second (log-)softmax to them, so use the negative log-likelihood of the
# log-probabilities instead.
criterion = nn.NLLLoss()
PROB_EPS = 1e-12  # floor applied before log() so a zero probability cannot give -inf

# Train the model on some data
for epoch in range(num_epochs):
    model.train()  # dropout active while training
    train_loss = 0.0
    train_acc = 0.0
    for images, labels in trainloader:
        images = images.view(images.size(0), -1).to(device)
        labels = labels.to(device)
        # Forward pass: row-wise class probabilities from the Gumbel-Softmax head.
        probs = model(images)
        loss = criterion(torch.log(probs.clamp_min(PROB_EPS)), labels)

        # Backward pass and optimization step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item() * images.size(0)
        train_acc += (probs.argmax(dim=1) == labels).sum().item()

    train_loss /= len(trainloader.dataset)
    train_acc /= len(trainloader.dataset)

    # Evaluate on the test set with dropout disabled (the model still adds Gumbel noise).
    model.eval()
    test_loss = 0.0
    test_acc = 0.0
    with torch.no_grad():
        for images, labels in testloader:
            images = images.view(images.size(0), -1).to(device)
            labels = labels.to(device)
            probs = model(images)
            test_loss += criterion(torch.log(probs.clamp_min(PROB_EPS)), labels).item() * images.size(0)
            test_acc += (probs.argmax(dim=1) == labels).sum().item()

    test_loss /= len(testloader.dataset)
    test_acc /= len(testloader.dataset)

    # Print the epoch number, train and test losses, and train and test accuracies
    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}")


In [None]:
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix

def evaluate_model(model, dataloader, device=None):
    """Evaluate a classifier on ``dataloader`` and return scikit-learn metrics.

    The model is moved to ``device`` and switched to eval mode; both changes
    persist after the call. Predictions are the arg-max over dim 1 of the
    model output.

    Args:
        model: ``nn.Module`` mapping an input batch to per-class scores of
            shape (batch, n_classes).
        dataloader: iterable of ``(inputs, labels)`` batches with tensor labels.
        device: optional ``torch.device`` or string; ``None`` selects CUDA when
            available, otherwise the CPU.

    Returns:
        ``(accuracy, precision, recall, f1, confusion_matrix)``. Precision,
        recall and F1 are support-weighted averages over classes; ill-defined
        per-class scores (e.g. a class that is never predicted) count as 0
        without emitting ``UndefinedMetricWarning``.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    true_labels = []
    predicted_labels = []

    with torch.no_grad():
        for inputs, labels in dataloader:
            outputs = model(inputs.to(device))
            # Labels are only compared on the host, so they never need to visit the device.
            true_labels.extend(labels.tolist())
            predicted_labels.extend(outputs.argmax(dim=1).tolist())

    accuracy = accuracy_score(true_labels, predicted_labels)
    precision = precision_score(true_labels, predicted_labels, average='weighted', zero_division=0)
    recall = recall_score(true_labels, predicted_labels, average='weighted', zero_division=0)
    f1 = f1_score(true_labels, predicted_labels, average='weighted', zero_division=0)
    cm = confusion_matrix(true_labels, predicted_labels)

    return accuracy, precision, recall, f1, cm


In [None]:
# Compare the ResNet trained with a standard softmax/cross-entropy head (`net`)
# against the Gumbel-Softmax MLP (`model`). Each report evaluates its own model;
# evaluating `model` twice would label the Gumbel MLP's scores "Standard Softmax".
# NOTE(review): the two models also differ in architecture (ResNet vs. MLP), so
# this is not a controlled comparison of the output layer alone.
for title, candidate in (("Standard Softmax:", net), ("Gumbel-Softmax:", model)):
    test_acc, test_prec, test_rec, test_f1, test_cm = evaluate_model(candidate, testloader)
    print(title)
    print(f"Test Accuracy: {test_acc:.4f}")
    print(f"Test Precision: {test_prec:.4f}")
    print(f"Test Recall: {test_rec:.4f}")
    print(f"Test F1 Score: {test_f1:.4f}")
    print("Confusion Matrix:")
    print(test_cm)
