In [2]:
import torch
from matplotlib import pyplot as plt
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18
from tqdm.notebook import tqdm
import numpy as np
import torch.nn.functional as F
import os
import pickle
from attack_pgd import pgd
from torch.nn.parallel import DataParallel
import logging
from torch.utils.tensorboard import SummaryWriter
from torchcontrib.optim import SWA
# Set the device
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Set the random seed for reproducibility
torch.manual_seed(42)

# Define constants
UNLABELED_BS = 50
TRAIN_BS = 50
TEST_BS = 128
num_train_samples = 5000
samples_per_class = int(num_train_samples / 10)

num_validation_samples = int(num_train_samples * 0.1)


# Define transformations for data augmentation
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

# Load CIFAR-10 dataset
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

train_loader = torch.utils.data.DataLoader(trainset, batch_size=TRAIN_BS, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(testset, batch_size=TEST_BS, shuffle=False, num_workers=2)

# Create labeled and unlabeled datasets
labeled_indices = []
unlabeled_indices = []

# Split the trainset into labeled and unlabeled data
for class_index in range(10):
    class_indices = torch.where(torch.Tensor(trainset.targets) == class_index)[0]
    labeled_indices.extend(class_indices[:samples_per_class])
    unlabeled_indices.extend(class_indices[samples_per_class:])

labeled_dataset = torch.utils.data.Subset(trainset, labeled_indices)
unlabeled_dataset = torch.utils.data.Subset(trainset, unlabeled_indices)

# Create data loaders for labeled and unlabeled data
labeled_loader = torch.utils.data.DataLoader(labeled_dataset, batch_size=TRAIN_BS, shuffle=True, num_workers=2)
unlabeled_loader = torch.utils.data.DataLoader(unlabeled_dataset, batch_size=UNLABELED_BS, shuffle=True, num_workers=2)

#labeled_train_subset, labeled_validation_subset = torch.utils.data.random_split(
#    labeled_dataset, [num_train_samples - num_validation_samples, num_validation_samples])

#labeled_loader = torch.utils.data.DataLoader(
#    labeled_train_subset, batch_size=TRAIN_BS, shuffle=True, num_workers=2)

#labeled_validation_loader = torch.utils.data.DataLoader(
#    labeled_validation_subset, batch_size=TEST_BS, shuffle=False, num_workers=2)


Files already downloaded and verified
Files already downloaded and verified


In [3]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.nn import functional as F
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
from torch.nn.utils import weight_norm

import sys
import math
import itertools
from torch.autograd import Variable, Function

In [4]:
# For cifar_cnn
class CNN(nn.Module):
    """
    CNN from Mean Teacher paper
    """

    def __init__(self, num_classes=10, isL2 = False, dropRatio = 0.0):
        super(CNN, self).__init__()

        self.isL2 = isL2

        self.activation = nn.LeakyReLU(0.1)
        self.conv1a = weight_norm(nn.Conv2d(3, 128, 3, padding=1))
        self.bn1a = nn.BatchNorm2d(128)
        self.conv1b = weight_norm(nn.Conv2d(128, 128, 3, padding=1))
        self.bn1b = nn.BatchNorm2d(128)
        self.conv1c = weight_norm(nn.Conv2d(128, 128, 3, padding=1))
        self.bn1c = nn.BatchNorm2d(128)
        self.mp1 = nn.MaxPool2d(2, stride=2, padding=0)
        # self.drop1  = nn.Dropout(0.5)
        # self.drop1  = nn.Dropout(dropRatio)
        self.drop  = nn.Dropout(dropRatio)

        self.conv2a = weight_norm(nn.Conv2d(128, 256, 3, padding=1))
        self.bn2a = nn.BatchNorm2d(256)
        self.conv2b = weight_norm(nn.Conv2d(256, 256, 3, padding=1))
        self.bn2b = nn.BatchNorm2d(256)
        self.conv2c = weight_norm(nn.Conv2d(256, 256, 3, padding=1))
        self.bn2c = nn.BatchNorm2d(256)
        self.mp2 = nn.MaxPool2d(2, stride=2, padding=0)
        # self.drop2  = nn.Dropout(0.5)
        # self.drop2  = nn.Dropout(dropRatio)

        self.conv3a = weight_norm(nn.Conv2d(256, 512, 3, padding=0))
        self.bn3a = nn.BatchNorm2d(512)
        self.conv3b = weight_norm(nn.Conv2d(512, 256, 1, padding=0))
        self.bn3b = nn.BatchNorm2d(256)
        self.conv3c = weight_norm(nn.Conv2d(256, 128, 1, padding=0))
        self.bn3c = nn.BatchNorm2d(128)
        self.ap3 = nn.AvgPool2d(6, stride=2, padding=0)

        self.fc1 =  weight_norm(nn.Linear(128, num_classes))
        self.fc2 =  weight_norm(nn.Linear(128, num_classes))

    def forward(self, x, debug=False):
        x = self.activation(self.bn1a(self.conv1a(x)))
        x = self.activation(self.bn1b(self.conv1b(x)))
        x = self.activation(self.bn1c(self.conv1c(x)))
        x = self.mp1(x)
        x = self.drop(x)

        x = self.activation(self.bn2a(self.conv2a(x)))
        x = self.activation(self.bn2b(self.conv2b(x)))
        x = self.activation(self.bn2c(self.conv2c(x)))
        x = self.mp2(x)
        x = self.drop(x)

        x = self.activation(self.bn3a(self.conv3a(x)))
        x = self.activation(self.bn3b(self.conv3b(x)))
        x = self.activation(self.bn3c(self.conv3c(x)))
        x = self.ap3(x)

        x = x.view(-1, 128)
        if self.isL2:
            x = F.normalize(x)
        # return self.fc1(x), self.fc2(x), x
        return self.fc1(x)#, self.fc2(x), x

        

In [None]:
"""
from models.wideresnet import WideResNet
import torch
from torch.nn import Sequential, Module

def get_model(name, num_classes=10, normalize_input=False):
    name_parts = name.split('-')
    if name_parts[0] == 'wrn':
        depth = int(name_parts[1])
        widen = int(name_parts[2])
        model = WideResNet(
            depth=depth, num_classes=num_classes, widen_factor=widen)
    if normalize_input:
        model = Sequential(NormalizeInput(), model)

    return model

class NormalizeInput(Module):
    def __init__(self, mean=(0.4914, 0.4822, 0.4465),
                 std=(0.2023, 0.1994, 0.2010)):
        super().__init__()

        self.register_buffer('mean', torch.Tensor(mean).reshape(1, -1, 1, 1))
        self.register_buffer('std', torch.Tensor(std).reshape(1, -1, 1, 1))

    def forward(self, x):
        return (x - self.mean) / self.std
"""

In [5]:
from __future__ import print_function
import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.optim as optim
from torchvision import datasets, transforms

#from models.wideresnet import *
#from models.resnet import *
#from losses import trades_loss

In [6]:
model_dir = './rst_adv/combined_loader/13CNN/mixup'
if not os.path.exists(model_dir):
    os.makedirs(model_dir)

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(message)s",
    handlers=[
        logging.FileHandler(os.path.join(model_dir, 'fly_soft_SWA70_CNN_lowBS_50_50_50k_training.log')),
        logging.StreamHandler()
    ])
logger = logging.getLogger()

In [7]:
##############################################################################

def loss_soft_reg_ep(preds, labels, soft_labels, device):
    num_classes = 10
    prob = F.softmax(preds, dim=1)
    prob_avg = torch.mean(prob, dim=0)
    p = torch.ones(num_classes).to(device) / args.num_classes

    L_c = -torch.mean(torch.sum(soft_labels * F.log_softmax(preds, dim=1), dim=1))   # Soft labels
    L_p = -torch.sum(torch.log(prob_avg) * p)
    L_e = -torch.mean(torch.sum(prob * F.log_softmax(preds, dim=1), dim=1))

    loss = L_c + 0.8 * L_p + 0.4 * L_e
    return prob, loss

##############################################################################
def mixup_data(x, y, alpha=1.0, device='cuda'):
    '''Returns mixed inputs, pairs of targets, and lambda'''
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size()[0]
    if device=='cuda':
        index = torch.randperm(batch_size).to(device)
    else:
        index = torch.randperm(batch_size)

    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam





In [8]:
"""
def loss_mixup_reg_ep(preds, labels, targets_a, targets_b, device, lam):
    num_classes = 10
    prob = F.softmax(preds, dim=1)
    prob_avg = torch.mean(prob, dim=0)
    p = torch.ones(num_classes).to(device) / num_classes

    mixup_loss_a = -torch.mean(torch.sum(targets_a * F.log_softmax(preds, dim=1), dim=1))
    mixup_loss_b = -torch.mean(torch.sum(targets_b * F.log_softmax(preds, dim=1), dim=1))
    mixup_loss = lam * mixup_loss_a + (1 - lam) * mixup_loss_b

    L_p = -torch.sum(torch.log(prob_avg) * p)
    L_e = -torch.mean(torch.sum(prob * F.log_softmax(preds, dim=1), dim=1))

    loss = mixup_loss + 0.8 * L_p + 0.4 * L_e
    return prob, loss
"""

In [12]:

def loss_mixup_reg_ep(preds, labels, targets_a, targets_b, device, lam):
    num_classes = preds.size(1)  # Get the number of classes from preds
    prob = F.softmax(preds, dim=1)
    prob_avg = torch.mean(prob, dim=0)
    p = torch.ones(num_classes).to(device) / num_classes

    # Ensure that targets_a and targets_b have the correct shape
    targets_a_onehot = F.one_hot(targets_a, num_classes=num_classes).float()
    targets_b_onehot = F.one_hot(targets_b, num_classes=num_classes).float()

    mixup_loss_a = -torch.mean(torch.sum(targets_a_onehot * F.log_softmax(preds, dim=1), dim=1))
    mixup_loss_b = -torch.mean(torch.sum(targets_b_onehot * F.log_softmax(preds, dim=1), dim=1))
    mixup_loss = lam * mixup_loss_a + (1 - lam) * mixup_loss_b

    L_p = -torch.sum(torch.log(prob_avg) * p)
    L_e = -torch.mean(torch.sum(prob * F.log_softmax(preds, dim=1), dim=1))

    loss = mixup_loss + 0.8 * L_p + 0.4 * L_e
    return prob, loss


In [13]:

step_size= 0.007
beta = 6
num_steps = 10
epsilon=0.031
#lr = 0.1
batch_size= 128
test_batch_size= 128
epochs = 150
weight_decay= 2e-4
momentum = 0.9
seed = 1
log_interval= 100
#model_dir= './rst_adv/trades'
distance = 'l_inf'
eval_attack_batches = 1
pgd_num_steps = 10
pgd_step_size = 0.007

swa_start = 70
swa_freq = 5


def train_new(model, device, labeled_loader, unlabeled_loader, optimizer, swa_optimizer, epoch):    

    #loss_ema = 0.
    alpha=1
    labeled_iter = iter(labeled_loader)    
    num_iter = (len(unlabeled_loader.dataset)//TRAIN_BS)+1
    for batch_idx, (x_unlabeled, _) in enumerate(unlabeled_loader):      
        try:
            x_labeled, y_labeled = next(labeled_iter)
        except StopIteration:
            labeled_iter = iter(labeled_loader)
            x_labeled, y_labeled = next(labeled_iter) 
        model.eval()
        with torch.no_grad():
            x_unlabeled = x_unlabeled.to(device)
            #pseudo_labels = model(x_unlabeled.to(device)).argmax(dim=1)
            #logits_unlabeled = model(x_unlabeled.to(device))
            # Softmax normalization to obtain probabilities for each class
            #pseudo_labels = F.softmax(logits_unlabeled, dim=1)
            
            logits_pseudo_labels = model(x_unlabeled.to(device))
            pseudo_labels = F.softmax(logits_pseudo_labels, dim=1)
            
        
        model.train()
        
        # concatenate (x_labeled, y_labeled) and (x_unlabeled, pseudo_labels)
        x_labeled = x_labeled.to(device)
        y_labeled = y_labeled.to(device)
        #print(y_labeled.shape)
        #print(y_labeled[0])
        # Add a new dimension to y_labeled to make it compatible for concatenation
        num_classes = 10
        y_labeled_onehot = F.one_hot(y_labeled, num_classes).float()
        #print(y_labeled_onehot.shape)
        #print(y_labeled_onehot[0])
        pseudo_labels = pseudo_labels.to(device)
        y_labeled_onehot = y_labeled_onehot.to(device)
        #print(pseudo_labels.shape)
        
        data = torch.cat([x_labeled,x_unlabeled],dim=0)
        target = torch.cat([y_labeled_onehot,pseudo_labels],dim=0)
        #target = torch.cat([y_labeled,pseudo_labels],dim=0)
        
        #if epoch < swa_start:
            #current_optimizer = optimizer
        #else:
            #current_optimizer = swa_optimizer
        
        optimizer.zero_grad()
        #output_x1 = model(data)
        #output_x1.detach_()
        #optimizer.zero_grad()
        
        images, targets_a, targets_b, lam = mixup_data(data, target, alpha, device)
        
        # calculate robust loss
        loss_trades = trades_loss(model=model,
                           x_natural=images,
                           y=target,
                           optimizer=optimizer,
                           step_size=0.007,
                           epsilon=0.031,
                           perturb_steps=10,
                           beta=6.0)
        
        outputs = model(images)
        #prob = F.softmax(output_x1, dim=1)
        prob_mixup, loss_reg = loss_mixup_reg_ep(outputs, target, targets_a, targets_b, device, lam)
        
        loss=loss_trades+loss_reg
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        #loss_ema = loss_ema * 0.9 + float(loss) * 0.1

        # print progress
        if batch_idx % log_interval == 0:
            logging.info('Train Epoch:{}\tLoss: {:.6f}'.format(epoch,loss.item()))
        #logging.info('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                #epoch, batch_idx * len(data), len(train_loader.dataset),
                       #100. * batch_idx / len(train_loader), loss.item()))
    if epoch > swa_start and epoch%swa_freq == 0 :
        swa_optimizer.update_swa()
    if epoch >= swa_start:
        swa_optimizer.bn_update(train_loader, model, device)


def train(model, device, train_loader, optimizer, swa_optimizer, epoch):
    alpha=1
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        
        images, targets_a, targets_b, lam = mixup_data(data, target, alpha, device)

        # calculate robust loss
        loss_trades = trades_loss(model=model,
                           x_natural=data,
                           y=target,
                           optimizer=optimizer,
                           step_size=step_size,
                           epsilon=epsilon,
                           perturb_steps=num_steps,
                           beta=beta)
        
        outputs = model(images)
        prob_mixup, loss_reg = loss_mixup_reg_ep(outputs, target, targets_a, targets_b, device, lam)
        
        loss=loss_trades+loss_reg
        
        loss.backward()
        optimizer.step()
        
        
        # print progress
        if batch_idx % log_interval == 0:
            logging.info('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset), 100. * batch_idx / len(train_loader), loss.item()))
        
    if epoch > swa_start and epoch%swa_freq == 0 :
        swa_optimizer.update_swa()
    if epoch >= swa_start:
        swa_optimizer.bn_update(train_loader, model, device)

        
            
"""
def eval_train(model, device, train_loader):
    model.eval()
    train_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            train_loss += F.cross_entropy(output, target, size_average=False).item()
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
    train_loss /= len(train_loader.dataset)
    print('Training: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        train_loss, correct, len(train_loader.dataset),
        100. * correct / len(train_loader.dataset)))
    training_accuracy = correct / len(train_loader.dataset)
    return train_loss, training_accuracy
"""
def eval_train(model, device, train_loader):
    model.eval()
    train_loss = 0
    correct = 0
    loss = 0
    total = 0
    correct = 0
    adv_correct = 0
    adv_correct_clean = 0
    adv_total = 0
    print_freq = 25
    
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            output = model(data)
            train_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            
            if batch_idx < eval_attack_batches:
                is_correct_clean, is_correct_rob = pgd(model, data, target,epsilon=epsilon,num_steps=2 * pgd_num_steps,step_size= pgd_step_size,random_start=False)
                incorrect_clean = (1-is_correct_clean).sum()
                incorrect_rob = (1-np.prod(is_correct_rob, axis=1)).sum()
            
                adv_correct_clean += (len(data) - int(incorrect_clean))
                adv_correct += (len(data) - int(incorrect_rob))
                adv_total += len(data)
            total += len(data)
    train_loss /= len(train_loader.dataset)
    if adv_total > 0:
        robust_clean_accuracy = adv_correct_clean / adv_total
        robust_accuracy = adv_correct / adv_total
    else:
        robust_accuracy = robust_clean_accuracy = 0.
    logging.info('Training: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%),'
        'PDG clean accuracy: {}/{} ({:.2f}%), Robust accuracy {}/{} ({:.2f}%)'.format(
        train_loss, correct, len(train_loader.dataset),
        100. * correct / len(train_loader.dataset),
        adv_correct_clean, adv_total, 100.0 * robust_clean_accuracy,
        adv_correct, adv_total, 100.0 * robust_accuracy))
    training_accuracy = correct / len(train_loader.dataset)
    return train_loss, training_accuracy, robust_accuracy



def eval_test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    correct = 0
    adv_correct = 0
    adv_correct_clean = 0
    adv_total = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(test_loader):
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            
            if batch_idx < eval_attack_batches:
                is_correct_clean, is_correct_rob = pgd(
                    model, data, target,
                    epsilon=epsilon,
                    num_steps=2 * pgd_num_steps,
                    step_size= pgd_step_size,
                    random_start=False)
                incorrect_clean = (1-is_correct_clean).sum()
                incorrect_rob = (1-np.prod(is_correct_rob, axis=1)).sum()
            
                adv_correct_clean += (len(data) - int(incorrect_clean))
                adv_correct += (len(data) - int(incorrect_rob))
                adv_total += len(data)
            total += len(data)
            
    test_loss /= len(test_loader.dataset)
    if adv_total > 0:
        robust_clean_accuracy = adv_correct_clean / adv_total
        robust_accuracy = adv_correct / adv_total
    else:
        robust_accuracy = robust_clean_accuracy = 0.
    logging.info('Test: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%),'
        'PDG clean accuracy: {}/{} ({:.2f}%), Robust accuracy {}/{} ({:.2f}%)'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset),
        adv_correct_clean, adv_total, 100.0 * robust_clean_accuracy,
        adv_correct, adv_total, 100.0 * robust_accuracy))
    test_accuracy = correct / len(test_loader.dataset)
    return test_loss, test_accuracy, robust_accuracy


def adjust_learning_rate(optimizer, epoch):
    """decrease the learning rate"""
    lr = 0.01
    if epoch >= 90:
        lr = lr * 0.1
    if epoch >= 130:
        lr = lr * 0.01
    if epoch >= 140:
        lr = lr * 0.001
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

        
        
def my_collate(batch):
    """Define collate_fn myself because the default_collate_fn throws errors like crazy"""
    # item: a tuple of (img, label)
    data = [item[0] for item in batch]
    target = [item[1] for item in batch]
    data = torch.stack(data)
    target = torch.LongTensor(target)
    return [data, target]

In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim


def squared_l2_norm(x):
    flattened = x.view(x.unsqueeze(0).shape[0], -1)
    return (flattened ** 2).sum(1)


def l2_norm(x):
    return squared_l2_norm(x).sqrt()


def trades_loss(model,
                x_natural,
                y,
                optimizer,
                step_size=0.003,
                epsilon=0.031,
                perturb_steps=10,
                beta=6,
                distance='l_inf'):
    # define KL-loss
    criterion_kl = nn.KLDivLoss(reduction='sum')
    model.eval()
    batch_size = len(x_natural)
    # generate adversarial example
    x_adv = x_natural.detach() + 0.001 * torch.randn(x_natural.shape).cuda().detach()
    if distance == 'l_inf':
        for _ in range(perturb_steps):
            x_adv.requires_grad_()
            with torch.enable_grad():
                loss_kl = criterion_kl(F.log_softmax(model(x_adv), dim=1),
                                       F.softmax(model(x_natural), dim=1))
            grad = torch.autograd.grad(loss_kl, [x_adv])[0]
            x_adv = x_adv.detach() + step_size * torch.sign(grad.detach())
            x_adv = torch.min(torch.max(x_adv, x_natural - epsilon), x_natural + epsilon)
            x_adv = torch.clamp(x_adv, 0.0, 1.0)
    elif distance == 'l_2':
        delta = 0.001 * torch.randn(x_natural.shape).cuda().detach()
        delta = Variable(delta.data, requires_grad=True)

        # Setup optimizers
        optimizer_delta = optim.SGD([delta], lr=epsilon / perturb_steps * 2)

        for _ in range(perturb_steps):
            adv = x_natural + delta

            # optimize
            optimizer_delta.zero_grad()
            with torch.enable_grad():
                loss = (-1) * criterion_kl(F.log_softmax(model(adv), dim=1),
                                           F.softmax(model(x_natural), dim=1))
            loss.backward()
            # renorming gradient
            grad_norms = delta.grad.view(batch_size, -1).norm(p=2, dim=1)
            delta.grad.div_(grad_norms.view(-1, 1, 1, 1))
            # avoid nan or inf if gradient is 0
            if (grad_norms == 0).any():
                delta.grad[grad_norms == 0] = torch.randn_like(delta.grad[grad_norms == 0])
            optimizer_delta.step()

            # projection
            delta.data.add_(x_natural)
            delta.data.clamp_(0, 1).sub_(x_natural)
            delta.data.renorm_(p=2, dim=0, maxnorm=epsilon)
        x_adv = Variable(x_natural + delta, requires_grad=False)
    else:
        x_adv = torch.clamp(x_adv, 0.0, 1.0)
    model.train()

    x_adv = Variable(torch.clamp(x_adv, 0.0, 1.0), requires_grad=False)
    # zero gradient
    optimizer.zero_grad()
    # calculate robust loss
    logits = model(x_natural)
    loss_natural = F.cross_entropy(logits, y)
    loss_robust = (1.0 / batch_size) * criterion_kl(F.log_softmax(model(x_adv), dim=1),
                                                    F.softmax(model(x_natural), dim=1))
    loss = loss_natural + beta * loss_robust
    return loss

In [15]:
def main():
    # init model, ResNet18() can be also used here for training
    #os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3"
    #device_ids = [1, 2, 3]
    #device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_classes = 10
    swa_lr = 0.0001
    model = CNN(num_classes=10, isL2 = False, dropRatio = 0.1)
    #model = get_model('wrn-28-10', num_classes=num_classes,
    #                      normalize_input=False)
    #model = model.to(device)

    if torch.cuda.device_count() > 1:
        model = DataParallel(model)
        print("Using", torch.cuda.device_count(), "GPUs!")
        model = model.to(device)
    lr =0.01
    print_freq = 50
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
    #optimizer = SWA(base_optimizer, swa_lr=swa_lr)
    swa_optimizer = SWA(optimizer, swa_start=swa_start, swa_freq=swa_freq, swa_lr=swa_lr)
    criterion = nn.CrossEntropyLoss()
    save_freq = 50
    sup_epochs = 150
    
    
    model.train()
    optimizer.zero_grad()
    
    # Forward pass for labeled data
    writer = SummaryWriter("runs/mixup/fly_soft_cnn_swa70_lowBS_50_50_50k_adversarial")
    for epoch in tqdm(range(sup_epochs)):
        lr = adjust_learning_rate(optimizer, epoch)
        #train(model, device, labeled_loader, optimizer, swa_optimizer, epoch)
        train(model, device, train_loader, optimizer, swa_optimizer, epoch)
        
        print('================================================================')
        cnn_swa_test_loss, cnn_swa_test_accuracy, cnn_swa_test_robust_accuracy = eval_test(model, device, test_loader)
        writer.add_scalar('fly_soft_swa70_robust_acc/test', cnn_swa_test_robust_accuracy, epoch)
        writer.add_scalar('fly_soft_swa70_clean_acc/test', cnn_swa_test_accuracy, epoch)
        print('================================================================')
        
        
        # save checkpoint
        if epoch % save_freq == 0:
            torch.save(model.state_dict(),
                       os.path.join(model_dir, 'fly_soft_swa70_lowBS_50_50_50k_model-cnn-epoch{}.pt'.format(epoch)))
            torch.save(optimizer.state_dict(),
                       os.path.join(model_dir, 'fly_soft_swa70_lowBS_50_50_50k_opt-cnn-checkpoint_epoch{}.tar'.format(epoch)))
    
    swa_optimizer.swap_swa_sgd()
    best_swa_test_loss, best_swa_test_accuracy, best_swa_test_robust_accuracy = eval_test(model, device, test_loader)
        #for x_labeled, y_labeled in labeled_loader:
        #    x_labeled, y_labeled = x_labeled.to(device), y_labeled.to(device)
        #    labeled_output = model(x_labeled)
        #    labeled_loss = criterion(labeled_output, y_labeled)

            # Backward pass and optimization
        #    labeled_loss.backward()
        #    optimizer.step()
        #    optimizer.zero_grad()

    #writer = SummaryWriter("runs/mixup/fly_soft_cnn_swa70_lowBS_50_50_1k_adversarial")
    #for epoch in tqdm(range(1, epochs + 1)):
        #lr = adjust_learning_rate(optimizer, epoch)
        #train_new(model, device, labeled_loader, unlabeled_loader, optimizer, swa_optimizer, epoch)
        #train(model, device, train_loader, optimizer, swa_optimizer, epoch)
        
        

        # evaluation on natural examples
        #print('================================================================')
        #val_train_loss, val_train_accuracy, val_robust_accuracy = eval_train(model, device, labeled_validation_loader)
        #writer.add_scalar('robust_acc/val', val_robust_accuracy, epoch)
        #writer.add_scalar('clean_acc/val', val_train_accuracy, epoch)
        
        #cnn_swa_test_loss, cnn_swa_test_accuracy, cnn_swa_test_robust_accuracy = eval_test(model, device, test_loader)
        #writer.add_scalar('fly_soft_swa70_robust_acc/test', cnn_swa_test_robust_accuracy, epoch)
       # writer.add_scalar('fly_soft_swa70_clean_acc/test', cnn_swa_test_accuracy, epoch)
        #print('================================================================')

        
       
        # save checkpoint
        #if epoch % save_freq == 0:
        #    torch.save(model.state_dict(),
        #               os.path.join(model_dir, 'fly_soft_swa70_lowBS_50_50_1k_model-cnn-epoch{}.pt'.format(epoch)))
        #    torch.save(optimizer.state_dict(),
        #               os.path.join(model_dir, 'fly_soft_swa70_lowBS_50_50_1k_opt-cnn-checkpoint_epoch{}.tar'.format(epoch)))
    
    #swa_optimizer.swap_swa_sgd()
    #best_swa_test_loss, best_swa_test_accuracy, best_swa_test_robust_accuracy = eval_test(model, device, test_loader)


if __name__ == '__main__':
    main()

Using 4 GPUs!


  0%|          | 0/150 [00:00<?, ?it/s]





2023-10-23 13:25:30,205 | Test: Average loss: 1.2049, Accuracy: 5673/10000 (57%),PDG clean accuracy: 70/128 (54.69%), Robust accuracy 29/128 (22.66%)








2023-10-23 13:34:23,353 | Test: Average loss: 1.0019, Accuracy: 6554/10000 (66%),PDG clean accuracy: 85/128 (66.41%), Robust accuracy 49/128 (38.28%)








2023-10-23 13:43:17,117 | Test: Average loss: 0.9152, Accuracy: 6948/10000 (69%),PDG clean accuracy: 93/128 (72.66%), Robust accuracy 63/128 (49.22%)








2023-10-23 13:52:11,133 | Test: Average loss: 0.8509, Accuracy: 7155/10000 (72%),PDG clean accuracy: 100/128 (78.12%), Robust accuracy 46/128 (35.94%)








2023-10-23 14:01:07,191 | Test: Average loss: 0.7914, Accuracy: 7281/10000 (73%),PDG clean accuracy: 93/128 (72.66%), Robust accuracy 57/128 (44.53%)








2023-10-23 14:10:03,902 | Test: Average loss: 0.7366, Accuracy: 7516/10000 (75%),PDG clean accuracy: 102/128 (79.69%), Robust accuracy 68/128 (53.12%)








2023-10-23 14:19:00,431 | Test: Average loss: 0.7080, Accuracy: 7586/10000 (76%),PDG clean accuracy: 101/128 (78.91%), Robust accuracy 64/128 (50.00%)








2023-10-23 14:27:57,612 | Test: Average loss: 0.7047, Accuracy: 7611/10000 (76%),PDG clean accuracy: 101/128 (78.91%), Robust accuracy 70/128 (54.69%)








2023-10-23 14:36:50,808 | Test: Average loss: 0.6815, Accuracy: 7755/10000 (78%),PDG clean accuracy: 108/128 (84.38%), Robust accuracy 75/128 (58.59%)








2023-10-23 14:45:46,494 | Test: Average loss: 0.6437, Accuracy: 7893/10000 (79%),PDG clean accuracy: 104/128 (81.25%), Robust accuracy 73/128 (57.03%)








2023-10-23 14:54:45,791 | Test: Average loss: 0.6431, Accuracy: 7887/10000 (79%),PDG clean accuracy: 107/128 (83.59%), Robust accuracy 73/128 (57.03%)








2023-10-23 15:03:40,861 | Test: Average loss: 0.6195, Accuracy: 7902/10000 (79%),PDG clean accuracy: 102/128 (79.69%), Robust accuracy 71/128 (55.47%)








2023-10-23 15:12:35,293 | Test: Average loss: 0.5973, Accuracy: 8019/10000 (80%),PDG clean accuracy: 108/128 (84.38%), Robust accuracy 84/128 (65.62%)








2023-10-23 15:21:30,202 | Test: Average loss: 0.6042, Accuracy: 7975/10000 (80%),PDG clean accuracy: 106/128 (82.81%), Robust accuracy 67/128 (52.34%)








2023-10-23 15:30:23,317 | Test: Average loss: 0.5877, Accuracy: 8023/10000 (80%),PDG clean accuracy: 102/128 (79.69%), Robust accuracy 73/128 (57.03%)








2023-10-23 15:39:14,859 | Test: Average loss: 0.5710, Accuracy: 8123/10000 (81%),PDG clean accuracy: 112/128 (87.50%), Robust accuracy 73/128 (57.03%)








2023-10-23 15:48:05,822 | Test: Average loss: 0.5413, Accuracy: 8242/10000 (82%),PDG clean accuracy: 109/128 (85.16%), Robust accuracy 72/128 (56.25%)








2023-10-23 15:56:58,259 | Test: Average loss: 0.5805, Accuracy: 8117/10000 (81%),PDG clean accuracy: 105/128 (82.03%), Robust accuracy 81/128 (63.28%)








2023-10-23 16:05:54,662 | Test: Average loss: 0.5309, Accuracy: 8242/10000 (82%),PDG clean accuracy: 105/128 (82.03%), Robust accuracy 68/128 (53.12%)








2023-10-23 16:14:48,957 | Test: Average loss: 0.5203, Accuracy: 8286/10000 (83%),PDG clean accuracy: 111/128 (86.72%), Robust accuracy 71/128 (55.47%)








2023-10-23 16:23:40,764 | Test: Average loss: 0.5126, Accuracy: 8355/10000 (84%),PDG clean accuracy: 111/128 (86.72%), Robust accuracy 69/128 (53.91%)








2023-10-23 16:32:36,745 | Test: Average loss: 0.5122, Accuracy: 8363/10000 (84%),PDG clean accuracy: 109/128 (85.16%), Robust accuracy 67/128 (52.34%)








2023-10-23 16:41:30,840 | Test: Average loss: 0.5182, Accuracy: 8330/10000 (83%),PDG clean accuracy: 110/128 (85.94%), Robust accuracy 75/128 (58.59%)








2023-10-23 16:50:24,347 | Test: Average loss: 0.4926, Accuracy: 8374/10000 (84%),PDG clean accuracy: 111/128 (86.72%), Robust accuracy 80/128 (62.50%)








2023-10-23 16:59:18,237 | Test: Average loss: 0.5183, Accuracy: 8371/10000 (84%),PDG clean accuracy: 109/128 (85.16%), Robust accuracy 74/128 (57.81%)








2023-10-23 17:08:15,376 | Test: Average loss: 0.4780, Accuracy: 8502/10000 (85%),PDG clean accuracy: 109/128 (85.16%), Robust accuracy 69/128 (53.91%)








2023-10-23 17:17:08,946 | Test: Average loss: 0.4824, Accuracy: 8408/10000 (84%),PDG clean accuracy: 113/128 (88.28%), Robust accuracy 80/128 (62.50%)








2023-10-23 17:26:09,646 | Test: Average loss: 0.4725, Accuracy: 8512/10000 (85%),PDG clean accuracy: 113/128 (88.28%), Robust accuracy 72/128 (56.25%)








2023-10-23 17:35:05,602 | Test: Average loss: 0.4738, Accuracy: 8465/10000 (85%),PDG clean accuracy: 112/128 (87.50%), Robust accuracy 68/128 (53.12%)








2023-10-23 17:44:04,585 | Test: Average loss: 0.4811, Accuracy: 8432/10000 (84%),PDG clean accuracy: 109/128 (85.16%), Robust accuracy 75/128 (58.59%)








2023-10-23 17:52:59,440 | Test: Average loss: 0.4643, Accuracy: 8467/10000 (85%),PDG clean accuracy: 109/128 (85.16%), Robust accuracy 78/128 (60.94%)








2023-10-23 18:01:55,086 | Test: Average loss: 0.4422, Accuracy: 8543/10000 (85%),PDG clean accuracy: 112/128 (87.50%), Robust accuracy 63/128 (49.22%)








2023-10-23 18:10:51,725 | Test: Average loss: 0.4491, Accuracy: 8559/10000 (86%),PDG clean accuracy: 115/128 (89.84%), Robust accuracy 71/128 (55.47%)








2023-10-23 18:19:47,687 | Test: Average loss: 0.4734, Accuracy: 8444/10000 (84%),PDG clean accuracy: 113/128 (88.28%), Robust accuracy 74/128 (57.81%)








2023-10-23 18:28:43,428 | Test: Average loss: 0.4581, Accuracy: 8490/10000 (85%),PDG clean accuracy: 111/128 (86.72%), Robust accuracy 77/128 (60.16%)








2023-10-23 18:37:38,335 | Test: Average loss: 0.4655, Accuracy: 8502/10000 (85%),PDG clean accuracy: 114/128 (89.06%), Robust accuracy 81/128 (63.28%)








2023-10-23 18:46:33,995 | Test: Average loss: 0.4553, Accuracy: 8458/10000 (85%),PDG clean accuracy: 109/128 (85.16%), Robust accuracy 75/128 (58.59%)








2023-10-23 18:55:32,252 | Test: Average loss: 0.4623, Accuracy: 8518/10000 (85%),PDG clean accuracy: 112/128 (87.50%), Robust accuracy 76/128 (59.38%)








2023-10-23 19:04:26,848 | Test: Average loss: 0.4728, Accuracy: 8490/10000 (85%),PDG clean accuracy: 105/128 (82.03%), Robust accuracy 73/128 (57.03%)








2023-10-23 19:13:23,536 | Test: Average loss: 0.4592, Accuracy: 8530/10000 (85%),PDG clean accuracy: 113/128 (88.28%), Robust accuracy 77/128 (60.16%)








2023-10-23 19:22:26,072 | Test: Average loss: 0.4372, Accuracy: 8558/10000 (86%),PDG clean accuracy: 112/128 (87.50%), Robust accuracy 67/128 (52.34%)








2023-10-23 19:31:25,552 | Test: Average loss: 0.4495, Accuracy: 8570/10000 (86%),PDG clean accuracy: 111/128 (86.72%), Robust accuracy 79/128 (61.72%)








2023-10-23 19:40:20,914 | Test: Average loss: 0.4444, Accuracy: 8567/10000 (86%),PDG clean accuracy: 112/128 (87.50%), Robust accuracy 68/128 (53.12%)








2023-10-23 19:49:19,913 | Test: Average loss: 0.4636, Accuracy: 8511/10000 (85%),PDG clean accuracy: 109/128 (85.16%), Robust accuracy 75/128 (58.59%)








2023-10-23 19:58:19,054 | Test: Average loss: 0.4456, Accuracy: 8504/10000 (85%),PDG clean accuracy: 110/128 (85.94%), Robust accuracy 75/128 (58.59%)








2023-10-23 20:07:14,703 | Test: Average loss: 0.4454, Accuracy: 8551/10000 (86%),PDG clean accuracy: 112/128 (87.50%), Robust accuracy 74/128 (57.81%)








2023-10-23 20:16:08,994 | Test: Average loss: 0.4502, Accuracy: 8522/10000 (85%),PDG clean accuracy: 108/128 (84.38%), Robust accuracy 72/128 (56.25%)








2023-10-23 20:25:07,231 | Test: Average loss: 0.4161, Accuracy: 8648/10000 (86%),PDG clean accuracy: 109/128 (85.16%), Robust accuracy 72/128 (56.25%)








2023-10-23 20:34:02,994 | Test: Average loss: 0.4055, Accuracy: 8705/10000 (87%),PDG clean accuracy: 113/128 (88.28%), Robust accuracy 67/128 (52.34%)








2023-10-23 20:42:56,429 | Test: Average loss: 0.4397, Accuracy: 8580/10000 (86%),PDG clean accuracy: 119/128 (92.97%), Robust accuracy 75/128 (58.59%)








2023-10-23 20:51:51,403 | Test: Average loss: 0.4150, Accuracy: 8699/10000 (87%),PDG clean accuracy: 113/128 (88.28%), Robust accuracy 74/128 (57.81%)








2023-10-23 21:00:45,774 | Test: Average loss: 0.4265, Accuracy: 8681/10000 (87%),PDG clean accuracy: 113/128 (88.28%), Robust accuracy 75/128 (58.59%)








2023-10-23 21:09:38,874 | Test: Average loss: 0.4627, Accuracy: 8540/10000 (85%),PDG clean accuracy: 110/128 (85.94%), Robust accuracy 75/128 (58.59%)








2023-10-23 21:18:27,259 | Test: Average loss: 0.5212, Accuracy: 8401/10000 (84%),PDG clean accuracy: 108/128 (84.38%), Robust accuracy 69/128 (53.91%)








2023-10-23 21:27:20,853 | Test: Average loss: 0.4580, Accuracy: 8478/10000 (85%),PDG clean accuracy: 111/128 (86.72%), Robust accuracy 68/128 (53.12%)








2023-10-23 21:36:16,100 | Test: Average loss: 0.4206, Accuracy: 8641/10000 (86%),PDG clean accuracy: 112/128 (87.50%), Robust accuracy 65/128 (50.78%)








2023-10-23 21:45:12,775 | Test: Average loss: 0.4948, Accuracy: 8422/10000 (84%),PDG clean accuracy: 109/128 (85.16%), Robust accuracy 74/128 (57.81%)








2023-10-23 21:54:08,037 | Test: Average loss: 0.4539, Accuracy: 8520/10000 (85%),PDG clean accuracy: 109/128 (85.16%), Robust accuracy 69/128 (53.91%)








2023-10-23 22:03:01,594 | Test: Average loss: 0.4579, Accuracy: 8575/10000 (86%),PDG clean accuracy: 110/128 (85.94%), Robust accuracy 68/128 (53.12%)








2023-10-23 22:11:58,611 | Test: Average loss: 0.4557, Accuracy: 8550/10000 (86%),PDG clean accuracy: 109/128 (85.16%), Robust accuracy 71/128 (55.47%)








2023-10-23 22:20:50,104 | Test: Average loss: 0.4824, Accuracy: 8451/10000 (85%),PDG clean accuracy: 106/128 (82.81%), Robust accuracy 76/128 (59.38%)








2023-10-23 22:29:44,678 | Test: Average loss: 0.4355, Accuracy: 8583/10000 (86%),PDG clean accuracy: 110/128 (85.94%), Robust accuracy 74/128 (57.81%)








2023-10-23 22:38:37,907 | Test: Average loss: 0.4561, Accuracy: 8545/10000 (85%),PDG clean accuracy: 103/128 (80.47%), Robust accuracy 68/128 (53.12%)








2023-10-23 22:47:31,128 | Test: Average loss: 0.4651, Accuracy: 8523/10000 (85%),PDG clean accuracy: 110/128 (85.94%), Robust accuracy 66/128 (51.56%)








2023-10-23 22:56:25,371 | Test: Average loss: 0.4986, Accuracy: 8416/10000 (84%),PDG clean accuracy: 108/128 (84.38%), Robust accuracy 72/128 (56.25%)








2023-10-23 23:05:19,240 | Test: Average loss: 0.5092, Accuracy: 8517/10000 (85%),PDG clean accuracy: 112/128 (87.50%), Robust accuracy 73/128 (57.03%)








2023-10-23 23:14:11,038 | Test: Average loss: 0.4326, Accuracy: 8613/10000 (86%),PDG clean accuracy: 110/128 (85.94%), Robust accuracy 70/128 (54.69%)








2023-10-23 23:23:05,981 | Test: Average loss: 0.4595, Accuracy: 8628/10000 (86%),PDG clean accuracy: 111/128 (86.72%), Robust accuracy 75/128 (58.59%)








2023-10-23 23:31:59,023 | Test: Average loss: 0.4916, Accuracy: 8506/10000 (85%),PDG clean accuracy: 110/128 (85.94%), Robust accuracy 69/128 (53.91%)








2023-10-23 23:40:52,386 | Test: Average loss: 0.4218, Accuracy: 8687/10000 (87%),PDG clean accuracy: 114/128 (89.06%), Robust accuracy 65/128 (50.78%)








2023-10-23 23:50:10,525 | Test: Average loss: 0.4085, Accuracy: 8624/10000 (86%),PDG clean accuracy: 114/128 (89.06%), Robust accuracy 77/128 (60.16%)








2023-10-23 23:59:30,362 | Test: Average loss: 0.4210, Accuracy: 8602/10000 (86%),PDG clean accuracy: 112/128 (87.50%), Robust accuracy 75/128 (58.59%)








2023-10-24 00:08:49,048 | Test: Average loss: 0.4378, Accuracy: 8589/10000 (86%),PDG clean accuracy: 114/128 (89.06%), Robust accuracy 78/128 (60.94%)








2023-10-24 00:18:03,770 | Test: Average loss: 0.4381, Accuracy: 8562/10000 (86%),PDG clean accuracy: 111/128 (86.72%), Robust accuracy 79/128 (61.72%)








2023-10-24 00:27:20,958 | Test: Average loss: 0.4110, Accuracy: 8681/10000 (87%),PDG clean accuracy: 115/128 (89.84%), Robust accuracy 74/128 (57.81%)








2023-10-24 00:36:40,514 | Test: Average loss: 0.3993, Accuracy: 8690/10000 (87%),PDG clean accuracy: 114/128 (89.06%), Robust accuracy 77/128 (60.16%)








2023-10-24 00:45:58,528 | Test: Average loss: 0.4469, Accuracy: 8523/10000 (85%),PDG clean accuracy: 115/128 (89.84%), Robust accuracy 90/128 (70.31%)








2023-10-24 00:55:16,399 | Test: Average loss: 0.4256, Accuracy: 8590/10000 (86%),PDG clean accuracy: 111/128 (86.72%), Robust accuracy 79/128 (61.72%)








2023-10-24 01:04:32,522 | Test: Average loss: 0.4302, Accuracy: 8578/10000 (86%),PDG clean accuracy: 110/128 (85.94%), Robust accuracy 82/128 (64.06%)








2023-10-24 01:13:47,283 | Test: Average loss: 0.4087, Accuracy: 8643/10000 (86%),PDG clean accuracy: 114/128 (89.06%), Robust accuracy 78/128 (60.94%)








2023-10-24 01:23:03,603 | Test: Average loss: 0.4176, Accuracy: 8639/10000 (86%),PDG clean accuracy: 109/128 (85.16%), Robust accuracy 72/128 (56.25%)








2023-10-24 01:32:22,533 | Test: Average loss: 0.4210, Accuracy: 8591/10000 (86%),PDG clean accuracy: 112/128 (87.50%), Robust accuracy 85/128 (66.41%)








2023-10-24 01:41:45,052 | Test: Average loss: 0.4300, Accuracy: 8605/10000 (86%),PDG clean accuracy: 110/128 (85.94%), Robust accuracy 76/128 (59.38%)








2023-10-24 01:51:03,492 | Test: Average loss: 0.4168, Accuracy: 8626/10000 (86%),PDG clean accuracy: 113/128 (88.28%), Robust accuracy 78/128 (60.94%)








2023-10-24 02:00:22,223 | Test: Average loss: 0.4055, Accuracy: 8664/10000 (87%),PDG clean accuracy: 111/128 (86.72%), Robust accuracy 75/128 (58.59%)








2023-10-24 02:09:39,472 | Test: Average loss: 0.4350, Accuracy: 8549/10000 (85%),PDG clean accuracy: 110/128 (85.94%), Robust accuracy 79/128 (61.72%)








2023-10-24 02:18:57,046 | Test: Average loss: 0.3953, Accuracy: 8674/10000 (87%),PDG clean accuracy: 112/128 (87.50%), Robust accuracy 77/128 (60.16%)








2023-10-24 02:28:11,919 | Test: Average loss: 0.4486, Accuracy: 8504/10000 (85%),PDG clean accuracy: 110/128 (85.94%), Robust accuracy 81/128 (63.28%)








2023-10-24 02:37:29,332 | Test: Average loss: 0.4143, Accuracy: 8628/10000 (86%),PDG clean accuracy: 113/128 (88.28%), Robust accuracy 81/128 (63.28%)








2023-10-24 02:46:47,949 | Test: Average loss: 0.4242, Accuracy: 8615/10000 (86%),PDG clean accuracy: 109/128 (85.16%), Robust accuracy 78/128 (60.94%)








2023-10-24 02:56:03,521 | Test: Average loss: 0.3549, Accuracy: 8816/10000 (88%),PDG clean accuracy: 115/128 (89.84%), Robust accuracy 83/128 (64.84%)








2023-10-24 03:05:21,098 | Test: Average loss: 0.3443, Accuracy: 8834/10000 (88%),PDG clean accuracy: 114/128 (89.06%), Robust accuracy 80/128 (62.50%)








2023-10-24 03:14:37,978 | Test: Average loss: 0.3367, Accuracy: 8856/10000 (89%),PDG clean accuracy: 114/128 (89.06%), Robust accuracy 83/128 (64.84%)








2023-10-24 03:23:54,627 | Test: Average loss: 0.3350, Accuracy: 8868/10000 (89%),PDG clean accuracy: 114/128 (89.06%), Robust accuracy 83/128 (64.84%)








2023-10-24 03:33:08,521 | Test: Average loss: 0.3427, Accuracy: 8842/10000 (88%),PDG clean accuracy: 114/128 (89.06%), Robust accuracy 82/128 (64.06%)








2023-10-24 03:42:27,298 | Test: Average loss: 0.3304, Accuracy: 8883/10000 (89%),PDG clean accuracy: 113/128 (88.28%), Robust accuracy 83/128 (64.84%)








2023-10-24 03:51:42,257 | Test: Average loss: 0.3314, Accuracy: 8893/10000 (89%),PDG clean accuracy: 112/128 (87.50%), Robust accuracy 81/128 (63.28%)








2023-10-24 04:00:55,727 | Test: Average loss: 0.3327, Accuracy: 8899/10000 (89%),PDG clean accuracy: 113/128 (88.28%), Robust accuracy 81/128 (63.28%)








2023-10-24 04:10:15,036 | Test: Average loss: 0.3192, Accuracy: 8932/10000 (89%),PDG clean accuracy: 116/128 (90.62%), Robust accuracy 82/128 (64.06%)








2023-10-24 04:19:39,627 | Test: Average loss: 0.3240, Accuracy: 8921/10000 (89%),PDG clean accuracy: 118/128 (92.19%), Robust accuracy 83/128 (64.84%)








2023-10-24 04:28:55,789 | Test: Average loss: 0.3213, Accuracy: 8910/10000 (89%),PDG clean accuracy: 116/128 (90.62%), Robust accuracy 84/128 (65.62%)








2023-10-24 04:38:14,633 | Test: Average loss: 0.3174, Accuracy: 8949/10000 (89%),PDG clean accuracy: 116/128 (90.62%), Robust accuracy 83/128 (64.84%)








2023-10-24 04:47:31,111 | Test: Average loss: 0.3161, Accuracy: 8954/10000 (90%),PDG clean accuracy: 114/128 (89.06%), Robust accuracy 82/128 (64.06%)








2023-10-24 04:56:46,343 | Test: Average loss: 0.3229, Accuracy: 8944/10000 (89%),PDG clean accuracy: 116/128 (90.62%), Robust accuracy 81/128 (63.28%)








2023-10-24 05:05:58,498 | Test: Average loss: 0.3103, Accuracy: 8968/10000 (90%),PDG clean accuracy: 114/128 (89.06%), Robust accuracy 76/128 (59.38%)








2023-10-24 05:15:15,442 | Test: Average loss: 0.3175, Accuracy: 8954/10000 (90%),PDG clean accuracy: 117/128 (91.41%), Robust accuracy 80/128 (62.50%)








2023-10-24 05:24:34,361 | Test: Average loss: 0.3103, Accuracy: 8962/10000 (90%),PDG clean accuracy: 116/128 (90.62%), Robust accuracy 79/128 (61.72%)








2023-10-24 05:33:49,277 | Test: Average loss: 0.3169, Accuracy: 8964/10000 (90%),PDG clean accuracy: 116/128 (90.62%), Robust accuracy 82/128 (64.06%)








2023-10-24 05:43:07,309 | Test: Average loss: 0.3089, Accuracy: 8975/10000 (90%),PDG clean accuracy: 117/128 (91.41%), Robust accuracy 82/128 (64.06%)








2023-10-24 05:52:26,809 | Test: Average loss: 0.3012, Accuracy: 8985/10000 (90%),PDG clean accuracy: 117/128 (91.41%), Robust accuracy 77/128 (60.16%)








2023-10-24 06:01:44,970 | Test: Average loss: 0.3066, Accuracy: 8970/10000 (90%),PDG clean accuracy: 117/128 (91.41%), Robust accuracy 80/128 (62.50%)








2023-10-24 06:11:00,790 | Test: Average loss: 0.3118, Accuracy: 8956/10000 (90%),PDG clean accuracy: 115/128 (89.84%), Robust accuracy 79/128 (61.72%)








2023-10-24 06:20:17,829 | Test: Average loss: 0.3105, Accuracy: 8973/10000 (90%),PDG clean accuracy: 115/128 (89.84%), Robust accuracy 83/128 (64.84%)








2023-10-24 06:29:34,830 | Test: Average loss: 0.3129, Accuracy: 8955/10000 (90%),PDG clean accuracy: 115/128 (89.84%), Robust accuracy 78/128 (60.94%)








2023-10-24 06:38:53,050 | Test: Average loss: 0.3085, Accuracy: 8964/10000 (90%),PDG clean accuracy: 115/128 (89.84%), Robust accuracy 79/128 (61.72%)








2023-10-24 06:48:10,654 | Test: Average loss: 0.3080, Accuracy: 8962/10000 (90%),PDG clean accuracy: 115/128 (89.84%), Robust accuracy 80/128 (62.50%)








2023-10-24 06:57:27,921 | Test: Average loss: 0.3145, Accuracy: 8949/10000 (89%),PDG clean accuracy: 115/128 (89.84%), Robust accuracy 81/128 (63.28%)








2023-10-24 07:06:49,451 | Test: Average loss: 0.3097, Accuracy: 8959/10000 (90%),PDG clean accuracy: 117/128 (91.41%), Robust accuracy 82/128 (64.06%)








2023-10-24 07:16:06,928 | Test: Average loss: 0.3140, Accuracy: 8964/10000 (90%),PDG clean accuracy: 118/128 (92.19%), Robust accuracy 77/128 (60.16%)








2023-10-24 07:25:22,111 | Test: Average loss: 0.3147, Accuracy: 8950/10000 (90%),PDG clean accuracy: 115/128 (89.84%), Robust accuracy 79/128 (61.72%)








2023-10-24 07:34:37,063 | Test: Average loss: 0.3081, Accuracy: 8976/10000 (90%),PDG clean accuracy: 114/128 (89.06%), Robust accuracy 76/128 (59.38%)








2023-10-24 07:43:55,806 | Test: Average loss: 0.3085, Accuracy: 8983/10000 (90%),PDG clean accuracy: 121/128 (94.53%), Robust accuracy 76/128 (59.38%)








2023-10-24 07:53:10,209 | Test: Average loss: 0.2955, Accuracy: 9011/10000 (90%),PDG clean accuracy: 117/128 (91.41%), Robust accuracy 78/128 (60.94%)








2023-10-24 08:02:27,494 | Test: Average loss: 0.3003, Accuracy: 8985/10000 (90%),PDG clean accuracy: 116/128 (90.62%), Robust accuracy 78/128 (60.94%)








2023-10-24 08:11:44,821 | Test: Average loss: 0.3010, Accuracy: 9001/10000 (90%),PDG clean accuracy: 116/128 (90.62%), Robust accuracy 79/128 (61.72%)








2023-10-24 08:21:02,164 | Test: Average loss: 0.3017, Accuracy: 8997/10000 (90%),PDG clean accuracy: 117/128 (91.41%), Robust accuracy 78/128 (60.94%)








2023-10-24 08:30:20,456 | Test: Average loss: 0.2976, Accuracy: 9012/10000 (90%),PDG clean accuracy: 117/128 (91.41%), Robust accuracy 79/128 (61.72%)








2023-10-24 08:39:35,610 | Test: Average loss: 0.3036, Accuracy: 8999/10000 (90%),PDG clean accuracy: 117/128 (91.41%), Robust accuracy 81/128 (63.28%)








2023-10-24 08:48:51,787 | Test: Average loss: 0.2987, Accuracy: 9008/10000 (90%),PDG clean accuracy: 117/128 (91.41%), Robust accuracy 74/128 (57.81%)








2023-10-24 08:58:07,528 | Test: Average loss: 0.3047, Accuracy: 9006/10000 (90%),PDG clean accuracy: 118/128 (92.19%), Robust accuracy 79/128 (61.72%)








2023-10-24 09:07:23,260 | Test: Average loss: 0.2973, Accuracy: 9035/10000 (90%),PDG clean accuracy: 118/128 (92.19%), Robust accuracy 79/128 (61.72%)








2023-10-24 09:16:37,771 | Test: Average loss: 0.2966, Accuracy: 9027/10000 (90%),PDG clean accuracy: 118/128 (92.19%), Robust accuracy 81/128 (63.28%)








2023-10-24 09:25:53,474 | Test: Average loss: 0.2943, Accuracy: 9030/10000 (90%),PDG clean accuracy: 118/128 (92.19%), Robust accuracy 81/128 (63.28%)








2023-10-24 09:35:14,907 | Test: Average loss: 0.2942, Accuracy: 9028/10000 (90%),PDG clean accuracy: 118/128 (92.19%), Robust accuracy 81/128 (63.28%)








2023-10-24 09:44:30,165 | Test: Average loss: 0.2939, Accuracy: 9030/10000 (90%),PDG clean accuracy: 118/128 (92.19%), Robust accuracy 81/128 (63.28%)








2023-10-24 09:53:48,855 | Test: Average loss: 0.2939, Accuracy: 9030/10000 (90%),PDG clean accuracy: 118/128 (92.19%), Robust accuracy 81/128 (63.28%)








2023-10-24 10:03:03,775 | Test: Average loss: 0.2920, Accuracy: 9037/10000 (90%),PDG clean accuracy: 118/128 (92.19%), Robust accuracy 81/128 (63.28%)








2023-10-24 10:12:18,377 | Test: Average loss: 0.2919, Accuracy: 9034/10000 (90%),PDG clean accuracy: 118/128 (92.19%), Robust accuracy 81/128 (63.28%)








2023-10-24 10:21:29,840 | Test: Average loss: 0.2924, Accuracy: 9032/10000 (90%),PDG clean accuracy: 118/128 (92.19%), Robust accuracy 80/128 (62.50%)








2023-10-24 10:30:45,656 | Test: Average loss: 0.2921, Accuracy: 9033/10000 (90%),PDG clean accuracy: 118/128 (92.19%), Robust accuracy 80/128 (62.50%)








2023-10-24 10:40:03,321 | Test: Average loss: 0.2925, Accuracy: 9035/10000 (90%),PDG clean accuracy: 118/128 (92.19%), Robust accuracy 81/128 (63.28%)








2023-10-24 10:49:20,370 | Test: Average loss: 0.2920, Accuracy: 9038/10000 (90%),PDG clean accuracy: 118/128 (92.19%), Robust accuracy 80/128 (62.50%)








2023-10-24 10:58:35,234 | Test: Average loss: 0.2917, Accuracy: 9039/10000 (90%),PDG clean accuracy: 118/128 (92.19%), Robust accuracy 81/128 (63.28%)








2023-10-24 11:07:52,702 | Test: Average loss: 0.2924, Accuracy: 9033/10000 (90%),PDG clean accuracy: 118/128 (92.19%), Robust accuracy 81/128 (63.28%)








2023-10-24 11:17:11,096 | Test: Average loss: 0.2923, Accuracy: 9035/10000 (90%),PDG clean accuracy: 118/128 (92.19%), Robust accuracy 81/128 (63.28%)








2023-10-24 11:26:30,075 | Test: Average loss: 0.2914, Accuracy: 9036/10000 (90%),PDG clean accuracy: 118/128 (92.19%), Robust accuracy 81/128 (63.28%)








2023-10-24 11:35:46,142 | Test: Average loss: 0.2919, Accuracy: 9035/10000 (90%),PDG clean accuracy: 118/128 (92.19%), Robust accuracy 80/128 (62.50%)








2023-10-24 11:45:01,968 | Test: Average loss: 0.2919, Accuracy: 9034/10000 (90%),PDG clean accuracy: 118/128 (92.19%), Robust accuracy 80/128 (62.50%)








2023-10-24 11:54:16,570 | Test: Average loss: 0.2917, Accuracy: 9034/10000 (90%),PDG clean accuracy: 118/128 (92.19%), Robust accuracy 81/128 (63.28%)








2023-10-24 12:03:31,337 | Test: Average loss: 0.2918, Accuracy: 9033/10000 (90%),PDG clean accuracy: 118/128 (92.19%), Robust accuracy 81/128 (63.28%)




2023-10-24 12:03:34,322 | Test: Average loss: 72.3867, Accuracy: 2656/10000 (27%),PDG clean accuracy: 32/128 (25.00%), Robust accuracy 61/128 (47.66%)


In [None]:
 from torchcontrib.optim import SWA
        base_optimizer = optim.SGD(model.parameters(), lr=lr, momentum=args.momentum, weight_decay=args.wd)
        optimizer = SWA(base_optimizer, swa_lr=args.swa_lr)

In [None]:
 # applying swa
    if args.swa == 'True':
        optimizer.swap_swa_sgd()
        optimizer.bn_update(train_loader, model, device)
        if args.validation_exp == "True":
            loss_swa, acc_val_swa = validating(args, model, device, test_loader)

In [None]:
if args.swa == 'True':
        if epoch > args.swa_start and epoch%args.swa_freq == 0 :
            optimizer.update_swa()

def main():
    # init model, ResNet18() can be also used here for training
    #os.environ["CUDA_VISIBLE_DEVICES"] = "1"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_classes = 10
    model = get_model('wrn-28-10', num_classes=num_classes,
                          normalize_input=False)
    model = model.to(device)

    if torch.cuda.device_count() > 1:
        print("Using", torch.cuda.device_count(), "GPUs!")
        model = DataParallel(model)
    #model = model.to(device)
    lr =0.1
    print_freq = 50
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()
    save_freq = 50
    
    
    model.train()
    optimizer.zero_grad()
    
    # Forward pass for labeled data
    for x_labeled, y_labeled in labeled_loader:
        x_labeled, y_labeled = x_labeled.to(device), y_labeled.to(device)
        labeled_output = model(x_labeled)
        labeled_loss = criterion(labeled_output, y_labeled)

        # Backward pass and optimization
        labeled_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    writer = SummaryWriter("runs/4k_adversarial")
    for epoch in tqdm(range(1, epochs + 1)):
        lr = adjust_learning_rate(optimizer, epoch)
        train_new(model, device, labeled_loader, unlabeled_loader, optimizer, epoch)

        # evaluation on natural examples
        print('================================================================')
        val_train_loss, val_train_accuracy, val_robust_accuracy = eval_train(model, device, labeled_validation_loader)
        writer.add_scalar('robust_acc/val', val_robust_accuracy, epoch)
        writer.add_scalar('clean_acc/val', val_train_accuracy, epoch)
        
        test_loss, test_accuracy, test_robust_accuracy = eval_test(model, device, test_loader)
        writer.add_scalar('robust_acc/test', test_robust_accuracy, epoch)
        writer.add_scalar('clean_acc/test', test_accuracy, epoch)
        print('================================================================')

        
       
        # save checkpoint
        if epoch % save_freq == 0:
            torch.save(model.state_dict(),
                       os.path.join(model_dir, '1_4k_model-wideres-epoch{}.pt'.format(epoch)))
            torch.save(optimizer.state_dict(),
                       os.path.join(model_dir, '1_4k_opt-wideres-checkpoint_epoch{}.tar'.format(epoch)))


if __name__ == '__main__':
    main()

In [None]:
model.train()
optimizer.zero_grad()

writer = SummaryWriter("runs/10k_adversarial")

for epoch in tqdm(range(1, epochs + 1)):
    # adjust learning rate for SGD
    adjust_learning_rate(optimizer, epoch, lr)
    
    # Forward pass for labeled data
    for x_labeled, y_labeled in labeled_loader:
        x_labeled, y_labeled = x_labeled.to(device), y_labeled.to(device)
        labeled_output = model(x_labeled)
        labeled_loss = criterion(labeled_output, y_labeled)

        # Backward pass and optimization
        labeled_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    # Generate pseudolabels for unlabeled data batch by batch
    model.eval()
    for x_unlabeled, _ in unlabeled_loader:
        data_array = x_unlabeled.to(device)
        unlabeled_output = model(data_array)
        pseudo_labels_batch = unlabeled_output.argmax(dim=1)

        # Convert data_array and pseudo_labels_batch to tensors
        data_array = data_array.cpu().numpy()
        pseudo_labels_batch = pseudo_labels_batch.cpu().numpy()

        # Perform training with labeled and pseudo-labeled batch
        combined_dataset = torch.utils.data.TensorDataset(torch.from_numpy(data_array),
                                                         torch.from_numpy(pseudo_labels_batch))
        combined_loader = torch.utils.data.DataLoader(combined_dataset,
                                                      batch_size=TRAIN_BS,
                                                      shuffle=True,
                                                      num_workers=2,
                                                      collate_fn=my_collate)

        train(model, device, combined_loader, optimizer, epoch)

    # evaluation on natural examples
    print('================================================================')
    train_loss, training_accuracy = eval_train(model, device, labeled_loader)
    eval_test(model, device, test_loader)
    print('================================================================')


In [None]:
def main():
    # init model, ResNet18() can be also used here for training
    #os.environ["CUDA_VISIBLE_DEVICES"] = "1"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_classes = 10
    model = get_model('wrn-28-10', num_classes=num_classes,
                          normalize_input=False)
    model = model.to(device)

    if torch.cuda.device_count() > 1:
        print("Using", torch.cuda.device_count(), "GPUs!")
        model = DataParallel(model)
    #model = model.to(device)
    lr =0.1
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()
    save_freq = 50
    # Load the combined_loader
    #with open('./data/combined_loader.pkl', 'rb') as f:
    #    combined_loader = pickle.load(f)
    
    model.train()
    optimizer.zero_grad()

    # Forward pass
    for x_labeled, y_labeled in labeled_loader:
        x_labeled, y_labeled = x_labeled.to(device), y_labeled.to(device)
        labeled_output = model(x_labeled)
        labeled_loss = criterion(labeled_output, y_labeled)

        # Backward pass and optimization
        labeled_loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    
    
    for epoch in tqdm(range(1, epochs + 1)):
        # adjust learning rate for SGD
        adjust_learning_rate(optimizer, epoch, lr)
        

        # Generate pseudo-labels for unlabeled data
        model.eval()
        pseudo_labels = []
        data_array = []
        with torch.no_grad():
            for x_unlabeled, _ in unlabeled_loader:
                data_array.append(x_unlabeled.numpy())
                x_unlabeled = x_unlabeled.to(device)
                
                unlabeled_output = model(x_unlabeled)
                pseudo_labels_batch = unlabeled_output.argmax(dim=1)
                pseudo_labels.append(pseudo_labels_batch.cpu().numpy())
        
        
        data_array = np.concatenate(data_array)
        new_extrapolated_targets = np.concatenate(pseudo_labels)
        
        # Convert data_array and new_extrapolated_targets to tensors
        data_array = torch.from_numpy(data_array)
        new_extrapolated_targets = torch.from_numpy(new_extrapolated_targets)
        
        
        # Create a new dataset by combining labeled data and predicted labels
        pseudolabel_combined_dataset = torch.utils.data.TensorDataset(data_array, new_extrapolated_targets)

        # Create a combined dataloader
        pseudolabel_combined_loader = torch.utils.data.DataLoader(pseudolabel_combined_dataset, batch_size=TRAIN_BS, shuffle=True, num_workers=2)
        # Combine labeled and pseudo-labeled data
        # Create a new pseudo-labels dataset
        combined_dataset = torch.utils.data.ConcatDataset([labeled_dataset, pseudolabel_combined_dataset])
        combined_loader = torch.utils.data.DataLoader(
            combined_dataset, batch_size=TRAIN_BS, shuffle=True, num_workers=2, collate_fn=my_collate)

        
 
        
        
        # adversarial training
        train(model, device, combined_loader, optimizer, epoch)

        # evaluation on natural examples
        print('================================================================')
        train_loss, training_accuracy = eval_train(model, device, labeled_loader)
        eval_test(model, device, test_loader)
        print('================================================================')

        
        writer.add_scalar('robust_accuracy', np.array(training_accuracy), epoch)
        # save checkpoint
        if epoch % save_freq == 0:
            torch.save(model.state_dict(),
                       os.path.join(model_dir, '10k_model-wideres-epoch{}.pt'.format(epoch)))
            torch.save(optimizer.state_dict(),
                       os.path.join(model_dir, '10k_opt-wideres-checkpoint_epoch{}.tar'.format(epoch)))


if __name__ == '__main__':
    main()





In [None]:
len(unlabeled_dataset)

In [None]:
len(c_dataset)

In [None]:
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()

        # calculate robust loss
        loss = trades_loss(model=model,
                           x_natural=data,
                           y=target,
                           optimizer=optimizer,
                           step_size=step_size,
                           epsilon=epsilon,
                           perturb_steps=num_steps,
                           beta=beta)
        loss.backward()
        optimizer.step()

        # print progress
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.item()))

In [None]:
class SemiSupervisedSampler(Sampler):
    """Balanced sampling from the labeled and unlabeled data"""
    def __init__(self, sup_inds, unsup_inds, batch_size, unsup_fraction=0.5,
                 num_batches=None):
        if unsup_fraction is None or unsup_fraction < 0:
            self.sup_inds = sup_inds + unsup_inds
            unsup_fraction = 0.0
        else:
            self.sup_inds = sup_inds
            self.unsup_inds = unsup_inds

        self.batch_size = batch_size
        unsup_batch_size = int(batch_size * unsup_fraction)
        self.sup_batch_size = batch_size - unsup_batch_size

        if num_batches is not None:
            self.num_batches = num_batches
        else:
            self.num_batches = int(
                np.ceil(len(self.sup_inds) / self.sup_batch_size))

        super().__init__(None)

    def __iter__(self):
        batch_counter = 0
        while batch_counter < self.num_batches:
            sup_inds_shuffled = [self.sup_inds[i]
                                 for i in torch.randperm(len(self.sup_inds))]
            for sup_k in range(0, len(self.sup_inds), self.sup_batch_size):
                if batch_counter == self.num_batches:
                    break
                batch = sup_inds_shuffled[sup_k:(sup_k + self.sup_batch_size)]
                if self.sup_batch_size < self.batch_size:
                    batch.extend([self.unsup_inds[i] for i in
                                  torch.randint(high=len(self.unsup_inds),
                                                size=(
                                                    self.batch_size - len(
                                                        batch),),
                                                dtype=torch.int64)])
                # this shuffle operation is very important, without it
                # batch-norm / DataParallel hell ensues
                np.random.shuffle(batch)
                yield batch
                batch_counter += 1

    def __len__(self):
        return self.num_batches


In [None]:
sup_inds = [1, 2, 3, 4, 5]  # Labeled data indices
unsup_inds = [6, 7, 8, 9, 10]  # Unlabeled data indices
batch_size = 4
unsup_fraction = 0.5

sampler = SemiSupervisedSampler(sup_inds, unsup_inds, batch_size, unsup_fraction)
dataloader = torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=batch_size)

for batch_indices in dataloader:
    # Use batch_indices for training
    print(batch_indices)

In [None]:
for batch_data, batch_labels in combined_loader:
    print("Batch data shape:", batch_data.shape)
    print("Batch labels:", batch_labels)


In [None]:
def train_new(model, device, labeled_loader, unlabeled_loader, optimizer, swa_optimizer, epoch):    

    labeled_iter = iter(labeled_loader)    
    num_iter = (len(unlabeled_loader.dataset)//TRAIN_BS)+1
    for batch_idx, (x_unlabeled, _) in enumerate(unlabeled_loader):      
        try:
            x_labeled, y_labeled = next(labeled_iter)
        except StopIteration:
            labeled_iter = iter(labeled_loader)
            x_labeled, y_labeled = next(labeled_iter) 
        model.eval()
        with torch.no_grad():
            x_unlabeled = x_unlabeled.to(device)
            
            logits_unlabeled = model(x_unlabeled.to(device))
            # Softmax normalization to obtain probabilities for each class
            pseudo_labels = F.softmax(logits_unlabeled, dim=1)
            
            
            
        
        model.train()
        
        # concatenate (x_labeled, y_labeled) and (x_unlabeled, pseudo_labels)
        x_labeled = x_labeled.to(device)
        y_labeled = y_labeled.to(device)
        #print(y_labeled.shape)
        #print(y_labeled[0])
        # Add a new dimension to y_labeled to make it compatible for concatenation
        num_classes = 10
        y_labeled_onehot = F.one_hot(y_labeled, num_classes)
        #print(y_labeled_onehot.shape)
        #print(y_labeled_onehot[0])
        pseudo_labels = pseudo_labels.to(device)
        #print(pseudo_labels.shape)
        
        data = torch.cat([x_labeled,x_unlabeled],dim=0)
        target = torch.cat([y_labeled,pseudo_labels],dim=0)
        
        optimizer.zero_grad()
        # calculate robust loss
        loss = trades_loss(model=model,
                           x_natural=data,
                           y=target,
                           optimizer=optimizer,
                           step_size=0.007,
                           epsilon=0.031,
                           perturb_steps=10,
                           beta=beta)
        loss.backward()
        optimizer.step()


In [None]:
""""
for images,labels in combined_loader:
         print(type(images))
         print(type(labels))
         print(images.size())
         print(images.dim())
""""

In [None]:
""""
model_dir= 'rst_adv/pseudolabel_cifar10/10k'
epoch = 100
normalize_input = False
checkpoint = torch.load(os.path.join(model_dir, 'checkpoint-epoch%d.pt' % epoch))
num_classes = checkpoint.get('num_classes', 10)
normalize_input = checkpoint.get('normalize_input', False)
model = get_model('wrn-28-10', 
                  num_classes=num_classes,
                  normalize_input=normalize_input)
model = model.to(device)
#model = nn.DataParallel(model).cuda()
model.load_state_dict(checkpoint['state_dict'])
"""

In [None]:
"""predictions = []
data_array = []
for i, (batch, _) in enumerate(unlabeled_loader):
    data_array.append(batch.numpy())
    _, preds = torch.max(model(batch.to(device)), dim=1)
    predictions.append(preds.cpu().numpy())
    
    if (i+1) % 10 == 0:
        print('Done %d/%d' % (i+1, len(unlabeled_loader)))
data_array = np.concatenate(data_array)
new_extrapolated_targets = np.concatenate(predictions)

# Convert data_array and new_extrapolated_targets to tensors
data_array = torch.from_numpy(data_array)
new_extrapolated_targets = torch.from_numpy(new_extrapolated_targets)

# Create a new dataset by combining labeled data and predicted labels
pseudolabel_combined_dataset = torch.utils.data.TensorDataset(data_array, new_extrapolated_targets)

# Create a combined dataloader
pseudolabel_combined_loader = torch.utils.data.DataLoader(pseudolabel_combined_dataset, batch_size=TRAIN_BS, shuffle=True, num_workers=2)

def my_collate(batch):
    """Define collate_fn myself because the default_collate_fn throws errors like crazy"""
    # item: a tuple of (img, label)
    data = [item[0] for item in batch]
    target = [item[1] for item in batch]
    data = torch.stack(data)
    target = torch.LongTensor(target)
    return [data, target]

combined_dataset = torch.utils.data.ConcatDataset([labeled_dataset, pseudolabel_combined_dataset])
combined_loader = torch.utils.data.DataLoader(
    combined_dataset, batch_size=TRAIN_BS, shuffle=True, num_workers=2, collate_fn=my_collate)


# Save the combined_loader
with open('./data/combined_loader.pkl', 'wb') as f:
    pickle.dump(combined_loader, f)
    """

In [None]:
import torch
import torch.nn.functional as F

# Example data
x_labeled = torch.randn((5, 3, 32, 32))  # Example labeled data
y_labeled = torch.tensor([1, 0, 2, 1, 0])  # Example class indices for labeled data

x_unlabeled = torch.randn((5, 3, 32, 32))  # Example unlabeled data

# Example model
class YourModel(torch.nn.Module):
    def __init__(self):
        super(YourModel, self).__init__()
        self.fc = torch.nn.Linear(3 * 32 * 32, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        return self.fc(x)

model = YourModel()

# Set the model to evaluation mode for pseudo-label generation
model.eval()

# Generate pseudo-labels
with torch.no_grad():
    logits_unlabeled = model(x_unlabeled)
    pseudo_labels = F.softmax(logits_unlabeled, dim=1)

# Ensure y_labeled has the same number of classes as the model
num_classes = 10
y_labeled = y_labeled.view(-1, 1).long()
y_labeled_onehot = F.one_hot(y_labeled, num_classes).float()  # Convert to one-hot encoding

# Concatenate class indices and softmax probabilities
target = torch.cat([y_labeled_onehot, pseudo_labels], dim=0)

print("Pseudo Labels:")
print(pseudo_labels)
print("Target (Concatenated):")
print(target)


In [None]:
import torch
import torch.nn.functional as F

# Example hard labels (class indices)
hard_labels = torch.tensor([1, 0, 2, 1, 0])  # Example shape: (batch_size,)

# Example number of classes
num_classes = 10

# Convert hard labels to one-hot encoding
hard_labels_onehot = F.one_hot(hard_labels, num_classes).float()

# Example soft label probabilities
soft_labels = F.softmax(torch.randn(5, 10), dim=1)  # Example shape: (batch_size, num_classes)

# Concatenate hard labels in one-hot encoding and soft labels
concatenated_labels = torch.cat([hard_labels_onehot, soft_labels], dim=0)

print("Hard Labels (One-Hot):")
print(hard_labels_onehot)
print("Soft Labels:")
print(soft_labels)
print("Concatenated Labels:")
print(concatenated_labels)
