In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as f
import torchvision.transforms as transforms
import torchvision
from torch.utils.data import DataLoader,Dataset,random_split
import PIL
import random
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
class Block(nn.Module):
  """Three 3x3 conv + ReLU stages; the last one doubles channels with stride 2.

  The first two convolutions keep both channel count and spatial size
  (padding=1, stride 1). The third doubles the channels and, through stride 2,
  halves the spatial size (rounding up).

  Args:
    in_channels: number of channels of the incoming feature map.
  """

  def __init__(self,in_channels):
    super().__init__()
    doubled = in_channels * 2
    stages = [
        nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.Conv2d(in_channels, doubled, kernel_size=3, stride=2, padding=1),
        nn.ReLU(),
    ]
    self.m1 = nn.Sequential(*stages)

  def forward(self, x):
    return self.m1(x)

class Encode(nn.Module):
  """Convolutional backbone followed by an MLP projection head.

  A stem conv maps 3 input channels to ``in_channels``. Five ``Block``s then
  each double the channel count and, through their stride-2 conv, halve the
  spatial size (rounding up).

  Args:
    in_channels: width of the stem convolution; the backbone emits
      ``in_channels * 32`` channels.
    input_size: side length of the square input images. The default of 96
      reproduces the previously hard-coded 9216 flattened features when
      ``in_channels`` is 32 (1024 channels x 3 x 3).
  """

  def __init__(self, in_channels, input_size=96):
    super().__init__()
    self.e = nn.Sequential(
        nn.Conv2d(3, in_channels, 3, padding=1),
        nn.ReLU(),
        Block(in_channels),
        Block(in_channels * 2),
        Block(in_channels * 4),
        Block(in_channels * 8),
        Block(in_channels * 16),
    )
    # Each Block ends in a 3x3 conv with stride 2 and padding 1, which maps a
    # side of length s to ceil(s / 2); apply that once per Block.
    side = input_size
    for _ in range(5):
      side = (side + 1) // 2
    # Flattened backbone width. It was hard-coded as 9216, which only matched
    # in_channels=32 with 96x96 inputs.
    self.feature_dim = in_channels * 32 * side * side
    self.e2 = nn.Sequential(
        nn.Linear(self.feature_dim, 2048),
        nn.ReLU(),
        nn.Linear(2048, 1024),
    )

  def forward(self, x, flag=True):
    """Encode a batch of images.

    Args:
      x: image batch, assumed (batch, 3, input_size, input_size) — TODO
        confirm for callers outside this notebook.
      flag: if True, return the 1024-wide projection from ``e2``; if False,
        return the flattened backbone features (``feature_dim`` wide).
    """
    # Flatten everything after the batch dim. This gives the same element order
    # as the former reshape to (B, -1, 1, 1) followed by two squeezes.
    x = torch.flatten(self.e(x), start_dim=1)
    if flag:
      x = self.e2(x)
    return x

class Network(nn.Module):
  """``Encode`` backbone with a linear classification head.

  Args:
    in_channels: stem width passed to ``Encode``.
    num_classes: output width of the head; defaults to 10, the previously
      hard-coded value.
  """

  def __init__(self, in_channels, num_classes=10):
    super().__init__()
    self.e = Encode(in_channels)
    # The head reads the flattened backbone features (Encode(..., flag=False)).
    # Those are exactly what the projection MLP's first Linear consumes, so
    # reuse its input width instead of hard-coding 9216.
    feature_dim = self.e.e2[0].in_features
    # The head emits raw logits. nn.CrossEntropyLoss applies log-softmax itself,
    # so the former trailing nn.Softmax() (which also had no explicit dim)
    # applied softmax twice and flattened the gradients. Accuracy via argmax is
    # unaffected because softmax is monotonic.
    self.c = nn.Sequential(
        nn.Linear(feature_dim, num_classes),
    )

  def forward(self, x, encode=True, classify=True):
    """Run the requested part of the model.

    Args:
      x: image batch.
      encode: when ``classify`` is False, return the encoder projection
        ``self.e(x, True)``.
      classify: return class logits computed from the backbone features;
        takes precedence over ``encode``.

    Returns:
      Class logits if ``classify``; otherwise the projection if ``encode``;
      otherwise ``x`` unchanged.
    """
    # Previously the defaults (encode=classify=True) fed the 1024-wide
    # projection back into the conv encoder, which fails on a 2-D input.
    # Classification now always starts from the images.
    if classify:
      return self.c(self.e(x, False))
    if encode:
      return self.e(x, True)
    return x
# Build the network with a 32-channel stem; Module.to moves it in place and returns it.
model = Network(32).to(device)

In [None]:
class Unsoupdata(Dataset):
  """STL-10 'unlabeled' split served as two augmented views of two images.

  Item ``i`` is ``(img_i, t(img_i), img_j, t(img_j))``. Here ``j != i`` is drawn
  uniformly at random and ``t`` is a random flip / colour-jitter augmentation.

  Args:
    root: directory STL-10 is downloaded to / read from. Defaults to '/', the
      previously hard-coded location, so existing downloads are reused.
    jitter: (brightness, contrast, saturation, hue) strengths for ColorJitter.
  """

  def __init__(self, root='/', jitter=(0.4, 0.4, 0.4, 0.1)):
    super().__init__()
    transform = transforms.ToTensor()
    # ColorJitter() with its default arguments has all strengths at 0, so it
    # changes nothing; pass explicit strengths so the colour augmentation
    # actually happens.
    self.t = transforms.Compose([
        transforms.RandomVerticalFlip(0.5),
        transforms.RandomHorizontalFlip(0.5),
        transforms.RandomApply([transforms.ColorJitter(*jitter)], p=0.5),
    ])
    self.data = torchvision.datasets.STL10(
        root=root, split='unlabeled', download=True, transform=transform)
    self.n = len(self.data)

  def __len__(self):
    return len(self.data)

  def __getitem__(self, i):
    if self.n < 2:
      raise ValueError('Unsoupdata needs at least two images to pick a distinct partner')
    if i < 0:
      i += self.n
    # Uniform over the n - 1 indices other than i: draw from [0, n - 2] and step
    # over i. This replaces the rejection loop, which never ended when n == 1.
    j = random.randrange(self.n - 1)
    if j >= i:
      j += 1
    # Index the dataset once per image (each access presumably re-applies its
    # transform) and augment the loaded tensors.
    anchor = self.data[i][0]
    partner = self.data[j][0]
    return anchor, self.t(anchor), partner, self.t(partner)

class Soupdata(Dataset):
  """Labelled STL-10 'train' split, returned item-for-item from torchvision.

  With the ToTensor transform used here, each item is an (image_tensor, label)
  pair.

  Args:
    root: directory STL-10 is downloaded to / read from. Defaults to '/', the
      previously hard-coded location, so existing downloads are reused.
  """

  def __init__(self, root='/'):
    super().__init__()
    self.data = torchvision.datasets.STL10(
        root=root, split='train', download=True, transform=transforms.ToTensor())
    self.n = len(self.data)

  def __len__(self):
    return len(self.data)

  def __getitem__(self, i):
    return self.data[i]

In [None]:
batch_size = 128
val_size = 500  # labelled STL-10 'train' images held out for validation

unsoup = Unsoupdata()
# Use the shared batch_size rather than a duplicated literal 128.
unsouploader = DataLoader(dataset=unsoup, batch_size=batch_size, shuffle=True)

soup = Soupdata()
# Seed the split so the train/validation partition is reproducible across runs.
train_data, test_val = random_split(
    soup, [len(soup) - val_size, val_size],
    generator=torch.Generator().manual_seed(0))
test_data = torchvision.datasets.STL10(
    root='/', split='test', download=True, transform=transforms.ToTensor())

souptrain = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
# Validation and test are each served as a single batch holding the whole
# split. Order does not affect accuracy, so they are not shuffled.
# NOTE(review): a whole-split batch is large; score it in slices rather than
# in one forward pass.
soupval = DataLoader(dataset=test_val, batch_size=len(test_val), shuffle=False)
souptest = DataLoader(dataset=test_data, batch_size=len(test_data), shuffle=False)

Downloading http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz to /stl10_binary.tar.gz


HBox(children=(FloatProgress(value=0.0, max=2640397119.0), HTML(value='')))


Extracting /stl10_binary.tar.gz to /
Files already downloaded and verified
Files already downloaded and verified


In [None]:
# Loss objects: a triplet loss (not referenced by the cells shown here),
# cross-entropy for the supervised head, and an initially empty scratch list.
coste, costc, sim = nn.TripletMarginLoss(), nn.CrossEntropyLoss(), []
def loss(intake,t = 1,eps = 1e-8):
  """Contrastive loss over two augmented image pairs.

  Args:
    intake: indexable of four embedding batches ordered
      (x_i, aug(x_i), x_j, aug(x_j)); a tensor stacked to (4, batch, dim), as
      built by the pretraining loop, works. Cosine similarity is taken per
      sample over the last dimension.
    t: temperature; the loss is divided by ``t``.
    eps: lower bound on the norms inside the cosine similarity.

  Returns:
    Scalar tensor: the batch mean of
    ``((s02 + s03 + s12 + s13) ** 2 - s01 * s23) / t``, where ``sab`` is the
    cosine similarity between views a and b.
  """
  def cos(a, b):
    # dim=-1 compares each sample's embeddings. The former dim=0 compared
    # feature columns across the batch.
    return F.cosine_similarity(intake[a], intake[b], dim=-1, eps=eps)

  positive = cos(0, 1) * cos(2, 3)
  negative = (cos(0, 2) + cos(0, 3) + cos(1, 2) + cos(1, 3)) ** 2
  # -log(exp(p / t) / exp(n / t)) is exactly (n - p) / t. Computing it directly
  # avoids the exp/log round trip. It also lets t take effect: previously it
  # was applied outside exp and cancelled in the ratio.
  return ((negative - positive) / t).mean()

In [None]:
# Adam, lr=1e-3 for all three:
#   optimizere - encoder parameters only (contrastive pretraining),
#   optimizerc - classifier head only (fine-tuning cell),
#   optimizerf - the whole network (not referenced by the cells shown here).
optimizere = torch.optim.Adam(params=model.e.parameters(), lr=1e-3)
optimizerc = torch.optim.Adam(params=model.c.parameters(), lr=1e-3)
optimizerf = torch.optim.Adam(params=model.parameters(), lr=1e-3)
def test(data):
    c = 0
    s = 0
    for i,(x,y) in enumerate(data):
        with torch.no_grad():
            x =x.to(device)
            y = y.to(device,dtype = torch.int64)
            yt = model(x,False,True)
            yt = torch.argmax(yt, dim= 1)
            c = (y == yt).sum()
            s = y.shape[0]
        break
    return (100*c/s).item()

In [None]:
epochs = 100
log_every = 20  # print the loss every `log_every` steps instead of every step
for j in range(epochs):
  for i, views in enumerate(unsouploader):
    # views = (x_i, aug(x_i), x_j, aug(x_j)), each a batch of images. The
    # network has no batch-dependent layers (only Conv2d/Linear/ReLU), so a
    # single forward pass over the concatenated views matches four separate
    # passes.
    stacked = torch.cat([v.to(device) for v in views], dim=0)
    embeddings = model(stacked, True, False)
    # Split back into the four views -> (4, batch, embed_dim), the layout
    # loss() indexes into.
    a = torch.stack(embeddings.chunk(4, dim=0), dim=0)
    losse = loss(a)
    optimizere.zero_grad()
    losse.backward()
    optimizere.step()
    if i % log_every == 0:
      print(f'epoch {j+1} step {i} loss {losse.item()}')


epoch 1 step 0 loss 15.0
epoch 1 step 1 loss 14.999998092651367
epoch 1 step 2 loss 15.0
epoch 1 step 3 loss 14.999999046325684
epoch 1 step 4 loss 14.990612983703613
epoch 1 step 5 loss 15.0
epoch 1 step 6 loss 14.999999046325684
epoch 1 step 7 loss 14.999988555908203
epoch 1 step 8 loss 14.999225616455078
epoch 1 step 9 loss 14.998809814453125
epoch 1 step 10 loss 14.949789047241211
epoch 1 step 11 loss 14.989103317260742
epoch 1 step 12 loss 14.999991416931152
epoch 1 step 13 loss 14.999320983886719
epoch 1 step 14 loss 14.999996185302734
epoch 1 step 15 loss 14.999985694885254
epoch 1 step 16 loss 14.99995231628418
epoch 1 step 17 loss 14.9996337890625
epoch 1 step 18 loss 14.999763488769531
epoch 1 step 19 loss 14.999542236328125
epoch 1 step 20 loss 14.99697494506836
epoch 1 step 21 loss 14.975410461425781
epoch 1 step 22 loss 14.979767799377441
epoch 1 step 23 loss 14.797109603881836
epoch 1 step 24 loss 12.336801528930664
epoch 1 step 25 loss 7.178368091583252
epoch 1 step 26 l

KeyboardInterrupt: ignored

In [None]:
epochs = 40
log_every = 20  # evaluate and print every `log_every` steps
for j in range(epochs):
  for i, (xt, yt) in enumerate(souptrain):
    xt = xt.to(device)
    yt = yt.to(device, dtype=torch.int64)
    # Linear probe: only model.c is optimized (optimizerc), so extract encoder
    # features without building an autograd graph through the backbone.
    with torch.no_grad():
      features = model.e(xt, False)
    # Keep the (batch, classes) shape. torch.squeeze would drop the batch dim
    # for a final batch of one and break the loss.
    y_pred = model.c(features)
    lossc = costc(y_pred, yt)
    optimizerc.zero_grad()
    lossc.backward()
    optimizerc.step()
    if i % log_every == 0:
      # soupval is the held-out slice of the labelled train split, not the test split.
      acc = test(soupval)
      print(f'epoch {j+1} step {i} loss {lossc.item()} val_accuracy {acc} train_accuracy {test(souptrain)}')

  input = module(input)


epoch 1 step 0 loss 2.312713384628296 test_accuracy 11.200000762939453 train_accuracy 8.59375
epoch 1 step 1 loss 2.359588384628296
epoch 1 step 2 loss 2.359588384628296
epoch 1 step 3 loss 2.375213384628296
epoch 1 step 4 loss 2.383025884628296
epoch 1 step 5 loss 2.375213384628296
epoch 1 step 6 loss 2.328338384628296
epoch 1 step 7 loss 2.343963384628296
epoch 1 step 8 loss 2.406463384628296
epoch 1 step 9 loss 2.328338384628296
epoch 1 step 10 loss 2.343963384628296
epoch 1 step 11 loss 2.406463384628296
epoch 1 step 12 loss 2.359588384628296
epoch 1 step 13 loss 2.383025884628296
epoch 1 step 14 loss 2.390838384628296
epoch 1 step 15 loss 2.328338384628296
epoch 1 step 16 loss 2.343963384628296
epoch 1 step 17 loss 2.406463384628296
epoch 1 step 18 loss 2.383025884628296
epoch 1 step 19 loss 2.375213384628296
epoch 1 step 20 loss 2.351775884628296 test_accuracy 11.200000762939453 train_accuracy 10.15625
epoch 1 step 21 loss 2.343963384628296
epoch 1 step 22 loss 2.383025884628296


KeyboardInterrupt: ignored

In [None]:
# Accuracy (percent) on the STL-10 test split, which souptest serves as one batch.
# NOTE(review): the RuntimeError recorded below may be a device out-of-memory
# from that whole-split batch; confirm.
test(data=souptest)

RuntimeError: ignored