In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchmetrics import Accuracy
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import trange, tqdm
import os

In [2]:
ds = datasets.CIFAR10("",download=False)
m=((ds.data/255).mean((0,1,2)))
s = ((ds.data/255).std((0,1,2)))
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize(m,s)])
train_set = datasets.CIFAR10(".",download=False,
                             transform=transform,target_transform=torch.tensor)
val_set = datasets.CIFAR10(".",download=False,transform=transform,
                           train=False,target_transform=torch.tensor)
train_loader = DataLoader(train_set,batch_size=128,shuffle=True)
val_loader = DataLoader(val_set,batch_size=128,shuffle=False)

In [3]:
class SimpleAE(nn.Module):
    def __init__(self):
        super(SimpleAE, self).__init__()
        self.encoder = nn.Sequential(
                                    nn.Conv2d(3,64, (3, 3), padding=(1, 1)),
                                    nn.ReLU(),
                                    nn.BatchNorm2d(64),
                                    nn.Conv2d(64,64, (3, 3), padding=(1, 1)),
                                    
   )
        self.decoder = nn.Sequential(nn.ConvTranspose2d(64,64,(3,3),padding=(1,1))
                                    ,nn.ReLU(),
                                    nn.BatchNorm2d(64),
                                     nn.ConvTranspose2d(64,3, (3, 3), padding=(1, 1)))
    def forward(self,X):
        return self.decoder(self.encoder(X))
    

In [4]:
model = SimpleAE()
model = model.cuda()
loss_fn = nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(),1e-3)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9,verbose=True)
train_loss = []
val_loss = []
epochs = 50

Adjusting learning rate of group 0 to 1.0000e-03.


In [5]:
if not os.path.isfile("model.pt"):
    outer_loop = trange(epochs)
    for epoch in outer_loop:
        r_loss = 0
        for X,y in tqdm(train_loader,leave = False):
            X = X.cuda()
            out = model(X)
            optimizer.zero_grad()
            loss = loss_fn(out,X)
            loss.backward()
            optimizer.step()
            r_loss += loss.item()
        train_loss.append(r_loss/len(train_loader))
        with torch.no_grad():
            r_loss = 0
            for X,y in tqdm(val_loader,leave = False):
                X = X.cuda()
                out = model(X)
                loss = loss_fn(out,X)
                r_loss += loss.item()
        val_loss.append(r_loss/len(val_loader))
        scheduler.step()
        outer_loop.set_postfix({"train loss":train_loss[-1],
                               "val loss":val_loss[-1]})
    plt.plot(train_loss)
    plt.plot(val_loss)
    torch.save(model.state_dict(),"model.pt")
else:
    model.load_state_dict(torch.load("model.pt"))

In [6]:
with torch.no_grad():
    outs = torch.zeros((10000,1600),device="cuda")
    ys = torch.zeros(10000,device="cuda")
    for i,(X,y) in tqdm(enumerate(val_loader),total = len(val_loader)):
        out = model.encoder(X.cuda())
        out = nn.functional.max_pool2d(out,(10,10),(5,5),padding = 0)
        out = torch.flatten(out,start_dim=1)
        outs[128*i:128*(i+1),:]=out
        ys[128*i:128*(i+1)] = y

  0%|          | 0/79 [00:00<?, ?it/s]

In [7]:
def SI_evaluation(outs,ys):
    with torch.no_grad():
        assert ys.get_device() != -1
        assert outs.get_device() != -1
        ys = ys
        m = outs @outs.T
        d = torch.diag(m)
        d = d.reshape(-1,1)
        w = torch.tile(d,(1,outs.shape[0]))
        D = w+w.T -2*m
        inf = torch.max(D)*100
        I = inf*torch.eye(D.shape[0]).cuda()
        D = D+I
        labs = torch.argmin(D,dim=1)
        labs = labs.detach().cpu().numpy()
        labs = ys[labs]
        return (labs == ys).sum()/len(labs)

In [8]:
# SI_evaluation(outs,ys)

In [9]:
def forward_selection(X,y):
    
    selected = []
    best_SIs = []
    done = False
    prv_SI = 0
    while not done:
        SI_list = []
        idx_list = []
        for i in trange(X.shape[1],leave=False):
            idx = selected.copy()
            if not i in idx:
                idx.append(i)
                features = X[:,idx]
                SI = SI_evaluation(features,y)
                SI_list.append(SI.item())
                idx_list.append(i)
        
        best_idx = np.argmax(SI_list)
        best_SI = SI_list[best_idx]
        best_feature = idx_list[best_idx]
        print(f"best SI:{best_SI} for {best_feature}")
        if best_SI<= prv_SI:
            print("best features obtaied")
            done = True
        else:
            selected.append(best_feature)
            best_SIs.append(best_SI)
            prv_SI = best_SI
    return selected,best_SIs

In [10]:
selected,best_SIs = forward_selection(outs,ys)

  0%|          | 0/1600 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
plt.plot(best_SIs)
plt.grid("on")
plt.xlabel("number of features")
plt.ylabel("SI")
plt.title("SI per selected features")

In [None]:
ys.device == torch.device("cuda:0")

In [None]:
torch.device("cuda:1")

In [None]:
ys.get_device()

In [None]:
a = torch.rand(10)

In [None]:
a.get_device()