In [1]:
import math
import time

import torch 
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

In [2]:

import matplotlib.pyplot as plt
from collections import OrderedDict

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import RobustScaler, StandardScaler
from sklearn.metrics import confusion_matrix, accuracy_score, f1_score

In [3]:
from utils.inference import Trainer, plot_loss
from utils.models import DNN
import utils.datasets as d
import utils.layers as layers

In [4]:
from IPython.core.debugger import set_trace

In [5]:
TMP_X = None
TMP_y= None

# Potřebné funkce

In [26]:
class One_Hot(object):
    def __init__(self, n_classes):
        self.n_classes = n_classes
        self.class_matrix = torch.diag(torch.ones(n_classes))

    def __call__(self, p):
        return self.class_matrix[p]

In [27]:
def gaussian_nll(y_true, mu, sigma, reduction="mean"):
    """
    Negative log likelihood (loss function) of gaussian random variable

    : param y_true: 	target value
    : param mu: 		mean of distribution ... size = (batch_size, output_dim)
    : param sigma: 		standard deviation of distribution  ... size (batch_size, 1) <– same variance

    returns mean loss per sample (not per point)
    """
    dim = mu.shape[1]/2
    var = sigma.pow(2).squeeze()

    if reduction=="mean":
        return - (torch.mean(torch.sum((y_true-mu).pow(2), axis=1)/(2*var) + dim*torch.log(var)) + dim*math.log(2*math.pi))
    if reduction=="none":
        return - (torch.sum((y_true-mu).pow(2), axis=1)/(2*var) + dim*torch.log(var) + dim*math.log(2*math.pi))

In [145]:
class SS_SVI(nn.Module):
    def __init__(self, model, likelihood="GaussianNLL", **kwargs):
        super(SS_SVI, self).__init__()
        self.model = model
        self.one_hot = One_Hot(n_classes=self.model.y_dim)

        if likelihood not in ["BCE", "GaussianNLL", "MSE"]:
            raise ValueError("Unknown likelihood")
        else:
            self.likelihood = likelihood
        
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") if kwargs.get("set_device") == None else kwargs.get("set_device")
        
        self.encoder_independent_of_y = kwargs.get("encoder_independent_of_y") if kwargs.get("encoder_independent_of_y") != None else True 

    def reconstruction_loss(self, y_pred, y_true, reduction="mean"):
        if self.likelihood == "BCE":
            reconstruction_loss = nn.BCELoss(reduction='sum')(y_pred, y_true)
        elif self.likelihood == "GaussianNLL":
            assert len(y_pred)==2
            mu_out, sigma_out = y_pred
            reconstruction_loss = gaussian_nll(y_true=y_true, mu=mu_out, sigma=sigma_out, reduction=reduction)
            #self.loss_history["MSE"].append(sample_mse(y_true=y_true, y_pred=mu_out).item())
        elif self.loss_fn == "MSE":
            reconstruction_loss = sample_mse(y_true=y_true, y_pred=y_pred)
            #nn.MSELoss(reduction='sum')(y_pred, y_true)
        return reconstruction_loss

    def forward(self, X, y=None):
        supervised = False if y is None else True

        #	y-> z 
        #	|   |
        #   ->x<-

        if supervised:
            p_y_pred = self.model.classify(X)
        
            y_oh = self.one_hot(y).to(self.device)
            if self.encoder_independent_of_y:
                z, z_mu, z_sigma = self.model.encoder(X) 
            else:
                z, z_mu, z_sigma = self.model.encoder(torch.cat([X,y_oh], axis=1)) 

            X_hat = self.model.decoder(torch.cat([z, y_oh], axis=1))

            # losses
            # classification loss
            loss_clf = nn.CrossEntropyLoss()(p_y_pred, y)

            # log p_y of prior distribution
            y_prior = 1/self.model.y_dim *torch.ones_like(y_oh)
            log_py = torch.mean(torch.log(y_prior), axis=1).to(self.device)

            # K-L divergence
            kld = - 0.5 * torch.sum(1 + torch.log(z_sigma.pow(2)) - z_mu.pow(2) - z_sigma.pow(2), axis=1)
            # with reduction ->  
            #    kld = - 0.5 * torch.mean(torch.sum(1 + torch.log(z_sigma.pow(2)) - z_mu.pow(2) - z_sigma.pow(2), axis=1))
            
            # log p_x ... reconstruction error
            log_px = self.reconstruction_loss(X_hat, X, reduction="none")

            # final loss of current flow
            likelihood = log_px + log_py - kld 

            return torch.mean(likelihood), loss_clf # returns scalar losses

        else:
            X_expanded = torch.cat(self.model.y_dim*[X]).float()

            # E[q(y|x)] = sum q(y|x) <- monte carlo improvement <- inaccurate decisions on start of training
            y_oh = []
            for i in range(self.model.y_dim):
                y_oh.append(i*torch.ones(X.shape[0]))
            y_oh_expanded = self.one_hot(torch.cat(y_oh,axis=0).long()).to(self.device)

            z, z_mu, z_sigma = self.model.encoder(torch.cat([X_expanded,y_oh_expanded], axis=1)) 

            X_hat = self.model.decoder(torch.cat([z, y_oh_expanded.float()], axis=1))

            y_pred = self.model.classify(X)
            p_y_pred = F.softmax(y_pred, dim=1)

            # losses
            kld = - 0.5 * torch.sum(1 + torch.log(z_sigma.pow(2)) - z_mu.pow(2) - z_sigma.pow(2), axis=1)

            log_px = self.reconstruction_loss(X_hat, X_expanded, reduction="none")
            
            """ podle mě stačí jen log -> that should be same
            y_prior = 1/self.model.y_dim *torch.ones_like(y_oh_expanded) #y_prior = 1/self.model.y_dim *torch.ones_like(p_y_pred)
            log_py = - nn.CrossEntropyLoss(reduction="none")(y_prior, torch.argmax(y_oh_expanded, axis=1)) #nn.CrossEntropyLoss(y_prior, p_y_pred, reduction="none")
            #https://github.com/wohlert/semi-supervised-pytorch/blob/master/semi-supervised/inference/distributions.py
            #https://github.com/wohlert/semi-supervised-pytorch/blob/master/semi-supervised/inference/variational.py
            """
            
            log_py = torch.log(torch.tensor(1./self.model.y_dim))
            
            likelihood = log_px + log_py - kld 

            likelihood = torch.mul(p_y_pred, likelihood.view(self.model.y_dim, X.shape[0]).T - torch.log(p_y_pred+1e-8))

            likelihood = torch.sum(likelihood, axis=1)

            return torch.mean(likelihood)


In [134]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:80% !important; }</style>"))

In [146]:
class Generative_Model_Trainer(nn.Module):
    def __init__(self, model, optimizer, scheduler=None, lr=1e-3, **kwargs):
        super(Generative_Model_Trainer, self).__init__()
        self.model = model
        self.optimizer = optimizer(self.model.parameters(), lr=1e-3)
        
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") if kwargs.get("set_device")==None else kwargs.get("set_device")
        print("Using device {}".format(self.device))
        
        self.elbo = SS_SVI(self.model, likelihood="GaussianNLL", set_device=self.device)
        self.scheduler = scheduler

        #optional params
        if kwargs.get("tensorboard") == True:
            self.tensorboard = True
            if kwargs.get("model_name")!= None:
                self.tb = SummaryWriter(comment=kwargs.get("model_name"))
            else:
                self.tb = SummaryWriter()
        else: 
            self.tensorboard = False
        
        self.elbo.encoder_independent_of_y = kwargs.get("encoder_independent_of_y") if kwargs.get("encoder_independent_of_y") != None else True 
        
        self.model = self.model.to(self.device)
        self.elbo = self.elbo.to(self.device)
        
        self.verbose = kwargs.get("verbose") if kwargs.get("verbose") != None else False
        if self.verbose:
            print(self)
        

    def reset_losses(self):
        self.loss_history = {
                            "train_total_loss":0., 
                            "train_classifier_loss":0., 
                            "train_supervised_loss":0., 
                            "train_unsupervised_loss":0.,
                            "validation_total_loss":0.,
                            "validation_classifier_loss":0.,
                            "validation_supervised_loss":0.,
                            "validation_unsupervised_loss":0.,
                            "validation_accuracy":0.
                             }
    
    def tensorboard_push_losses(self, epoch, n_train_batches, n_valid_batches):
        """
        function for saving losses to tensorboard
        """
        self.tb.add_scalar("Loss/train_total_loss", self.loss_history["train_total_loss"]/n_train_batches, epoch)
        self.tb.add_scalar("Loss/train_classifier_loss", self.loss_history["train_classifier_loss"]/n_train_batches, epoch)
        self.tb.add_scalar("Loss/train_supervised_loss", self.loss_history["train_supervised_loss"]/n_train_batches, epoch)
        self.tb.add_scalar("Loss/train_unsupervised_loss", self.loss_history["train_unsupervised_loss"]/n_train_batches, epoch)
        
        self.tb.add_scalar("Loss/validation_total_loss", self.loss_history["validation_total_loss"]/n_valid_batches, epoch)
        self.tb.add_scalar("Loss/validation_classifier_loss", self.loss_history["validation_classifier_loss"]/n_valid_batches, epoch)
        self.tb.add_scalar("Loss/validation_supervised_loss", self.loss_history["validation_supervised_loss"]/n_valid_batches, epoch)
        self.tb.add_scalar("Loss/validation_unsupervised_loss", self.loss_history["validation_unsupervised_loss"]/n_valid_batches, epoch)
        self.tb.add_scalar("Accuracy/validation". self.loss_history["validation_accuracy"]/n_valid_batches, epoch)


    def forward(self, epochs, supervised_dataset, unsupervised_dataset, validation_dataset, batch_size):
        if not isinstance(epochs, range):
            epochs = range(epochs)
        n_epochs = max(epochs)+1

        unsupervised = torch.utils.data.DataLoader(
                            dataset=torch.tensor(unsupervised_dataset).float(), 
                            batch_size=batch_size//2, 
                            shuffle=False, 
                            sampler=torch.utils.data.RandomSampler(
                                unsupervised_dataset, 
                                replacement=False
                                )
                            )
        supervised = torch.utils.data.DataLoader(
                            dataset=supervised_dataset,
                            batch_size=batch_size//2, 
                            shuffle=False, 
                            sampler=torch.utils.data.RandomSampler(
                                supervised_dataset, 
                                replacement=True,
                                num_samples=unsupervised_dataset.shape[0]
                                )
                            )

        validation = torch.utils.data.DataLoader(dataset=validation_dataset, batch_size=batch_size)

        for epoch in epochs:
            self.model.train()
            self.reset_losses()
            # ============== training ==============
            for i in range(len(unsupervised)):
                X_unsup = next(iter(unsupervised))
                X_sup, y_sup = next(iter(supervised))

                X_unsup = X_unsup.to(self.device)
                X_sup = X_sup.to(self.device)
                y_sup = y_sup.to(self.device)

                # ============== forward ===========
                L, CLF = self.elbo(X_sup, y_sup)
                L = -L
                U = -self.elbo(X_unsup)

                alpha = 0.1 * (batch_size//2)*2  # correction -> even numbers
                J = L + U + alpha*CLF

                # logging losses
                self.loss_history["train_total_loss"] += J.detach().item()
                self.loss_history["train_supervised_loss"] += L.detach().item()
                self.loss_history["train_classifier_loss"] += CLF.detach().item()
                self.loss_history["train_unsupervised_loss"] += U.detach().item()
                
                #print(f"J: {np.exp(J.detach().item())}, L: {L.detach().item()}, U: {U.detach().item()}, CLF: {CLF.detach().item()}")

                # ============ backward ============
                self.optimizer.zero_grad()
                J.backward()
                self.optimizer.step()

            # ============= validation =============
            self.model.eval()
            with torch.no_grad():
                acc = 0
                for x, y in validation:
                    x_sup = x[:batch_size//2]
                    y_sup = y[:batch_size//2]
                    x_unsup = x[batch_size//2:]
                    
                    x_sup = x_sup.to(self.device)
                    y_sup = y_sup.to(self.device)
                    x_unsup = x_unsup.to(self.device)

                    # ============== forward ===========
                    L, CLF = self.elbo(x_sup, y_sup)
                    L = -L
                    U = -self.elbo(x_unsup)

                    alpha = 0.1 * (batch_size//2)*2  # correction -> even numbers
                    J = L + U + alpha*CLF
                    # logging losses
                    self.loss_history["validation_total_loss"] += J.detach().item()
                    self.loss_history["validation_supervised_loss"] += L.detach().item()
                    self.loss_history["validation_classifier_loss"] += CLF.detach().item()
                    self.loss_history["validation_unsupervised_loss"] += U.detach().item()
                    
                    # classification
                    x = x.to(self.device)
                    y = y.to(self.device)
                    y_valid_pred = self.model.classify(x)
                    self.loss_history["validation_accuracy"] += accuracy_score(y.cpu().detach(), torch.argmax(y_valid_pred.cpu().detach(), axis=1))
            if self.verbose:
                print("Epoch [{}/{}], average_loss:{:.4f}, validation_loss:{:.4f}, val_accuracy:{:,.4f}"\
                        .format(epoch+1, n_epochs,self.loss_history["train_total_loss"]/len(unsupervised), self.loss_history["validation_total_loss"]/len(validation), self.loss_history["validation_accuracy"]/len(validation)))
                    
            if self.tensorboard:
                self.tensorboard_push_losses(epoch=epoch, n_train_batches=len(unsupervised), n_valid_batches=len(validation))
                
            if self.scheduler!=None:
                self.scheduler.step()



# Zadefinovat M2 model

In [30]:
class M2(nn.Module):
    def __init__(self, latent_dim, n_classes):
        super(M2, self).__init__()
        self.y_dim = n_classes
        self.encoder = nn.Sequential(
                            nn.Linear(in_features=160+4, out_features=400),
                            nn.ELU(),
                            nn.Linear(in_features=400, out_features=200),
                            nn.ELU(),
                            layers.VariationalLayer(in_features=200, out_features=latent_dim, return_KL=False)
                            )
        self.decoder = nn.Sequential(
                            nn.Linear(in_features=latent_dim+4, out_features=200),
                            nn.ELU(),
                            nn.Linear(in_features=200, out_features=400),
                            nn.ELU(),
                            layers.VariationalDecoderOutput(in_features=400, out_features=160)
                            )
        self.classify = nn.Sequential(
                            nn.Linear(in_features=160, out_features=400),
                            nn.ELU(),
                            nn.Linear(in_features=400, out_features=200),
                            nn.ELU(),
                            nn.Linear(in_features=200, out_features=4)
                            )

# Připravit data

In [13]:
X = np.vstack((np.load("data/sequenced_data_for_VAE_length-160_stride-10_pt1.npy"),
               np.load("data/sequenced_data_for_VAE_length-160_stride-10_pt2.npy")))
y = np.load("data/sequenced_data_for_VAE_length-160_stride-10_targets.npy")

In [14]:
print(X.shape, y.shape)

(83680, 160) (83680,)


In [15]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=666)

In [16]:
scaler = RobustScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

In [17]:
X_unsup, X_sup, y_unsup, y_sup = train_test_split(X_train, y_train, test_size=0.2, random_state=666)

In [18]:
sup = d.H_alphaSequences(X_sup, y_sup)
test = d.H_alphaSequences(X_test, y_test)

In [16]:
#sup_loader = torch.utils.data.DataLoader(dataset = sup, batch_size=512, shuffle=True)
#unsup_loader = torch.utils.data.DataLoader(dataset = X_unsup, batch_size=512, shuffle=True)
#test_loader = torch.utils.data.DataLoader(dataset = test, batch_size=512, shuffle=True)

In [17]:
unsupervised = torch.utils.data.DataLoader(
                    dataset=X_unsup, 
                    batch_size=64, 
                    shuffle=False, 
                    sampler=torch.utils.data.RandomSampler(
                        X_unsup, 
                        replacement=False
                        )
                    )
supervised = torch.utils.data.DataLoader(
                    dataset=sup,
                    batch_size=64, 
                    shuffle=False, 
                    sampler=torch.utils.data.RandomSampler(
                        sup, 
                        replacement=True,
                        num_samples=X_unsup.shape[0]
                        )
                    )

In [18]:
next(iter(unsupervised))

tensor([[ 0.5261,  0.5227,  0.5327,  ...,  0.4995,  0.5320,  0.5514],
        [-0.5435, -0.5480, -0.5401,  ..., -0.1841, -0.1531, -0.1185],
        [-0.1430, -0.1262, -0.0768,  ..., -0.0498, -0.0645, -0.0490],
        ...,
        [ 1.2675,  1.2575,  1.2475,  ...,  1.1055,  1.1334,  1.2037],
        [ 0.1804,  0.1748,  0.1696,  ...,  0.1585,  0.1507,  0.1609],
        [-0.6210, -0.6033, -0.5980,  ..., -0.0828, -0.0999, -0.1116]],
       dtype=torch.float64)

In [19]:
next(iter(supervised))

[tensor([[ 0.0599,  0.0448,  0.0088,  ..., -0.0847, -0.0775, -0.0940],
         [ 0.4667,  0.5947,  0.6855,  ...,  0.7528,  0.8496,  0.8665],
         [ 0.2817,  0.2601,  0.2753,  ...,  0.3532,  0.3645,  0.3327],
         ...,
         [ 1.9015,  1.9384,  2.0870,  ...,  1.7680,  1.7477,  1.7215],
         [-0.3499, -0.3711, -0.3812,  ..., -0.4057, -0.4134, -0.4139],
         [-0.2339, -0.2697, -0.2638,  ..., -0.6638, -0.6342, -0.6662]]),
 tensor([1, 1, 1, 1, 0, 0, 2, 1, 1, 1, 1, 1, 3, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0,
         1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 2, 1, 1, 1,
         1, 1, 1, 0, 1, 1, 1, 1, 3, 1, 0, 1, 1, 1, 1, 0])]

# kontrola dimenzí

In [20]:
model = M2(30,4)

In [21]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
elbo = SS_SVI(model, likelihood="GaussianNLL")

In [23]:
for i in range(10):
    X_unsup = next(iter(unsupervised))
    X_sup, y_sup = next(iter(supervised))

    # ============== forward ===========
    L, CLF = elbo(X_sup, y_sup)
    L = -L
    U = -elbo(X_unsup.float())

    alpha = 0.1 * 128
    J = L + U + alpha*CLF
    # ============= backward==============
    optimizer.zero_grad()
    J.backward()
    optimizer.step()
    print(f"J: {J.item()}, L: {L.item()}, U: {U.item()}, Clf: {CLF.item()}")



J: 305.3947448730469, L: 140.54623413085938, U: 151.11927795410156, Clf: 1.072594165802002
J: 304.8121643066406, L: 129.4605255126953, U: 159.42337036132812, Clf: 1.24439537525177
J: 324.5542907714844, L: 157.43711853027344, U: 154.90997314453125, Clf: 0.9536867141723633
J: 275.02392578125, L: 149.09475708007812, U: 113.04313659667969, Clf: 1.0067213773727417
J: 253.57777404785156, L: 127.16284942626953, U: 112.63513946533203, Clf: 1.0765453577041626
J: 258.19879150390625, L: 106.16460418701172, U: 139.84983825683594, Clf: 0.951903760433197
J: 204.4355926513672, L: 101.17811584472656, U: 90.52375793457031, Clf: 0.9948216676712036
J: 197.70603942871094, L: 91.07245635986328, U: 94.13410949707031, Clf: 0.9765209555625916
J: 214.12820434570312, L: 129.99945068359375, U: 73.35533142089844, Clf: 0.8416734933853149
J: 164.689697265625, L: 68.43266296386719, U: 85.19239807128906, Clf: 0.8644246459007263


In [25]:
for i in range(100):
    X_unsup = next(iter(unsupervised))
    X_sup, y_sup = next(iter(supervised))

    # ============== forward ===========
    L, CLF = elbo(X_sup, y_sup)
    L = -L
    U = -elbo(X_unsup.float())

    alpha = 0.1 * 128
    J = L + U + alpha*CLF
    # ============= backward==============
    optimizer.zero_grad()
    J.backward()
    optimizer.step()
    print(f"J: {J.item()}, L: {L.item()}, U: {U.item()}, Clf: {CLF.item()}")

J: 41.21916580200195, L: 7.773998737335205, U: 22.048633575439453, Clf: 0.8903540968894958
J: 140.17601013183594, L: 110.40641021728516, U: 18.319795608520508, Clf: 0.894515335559845
J: 52.332054138183594, L: 32.473018646240234, U: 8.90212631225586, Clf: 0.8560085892677307
J: 33.08345413208008, L: 0.8606748580932617, U: 20.63958740234375, Clf: 0.9049367904663086
J: 48.522491455078125, L: 27.207210540771484, U: 8.677841186523438, Clf: 0.9873000383377075
J: 29.506622314453125, L: 12.614885330200195, U: 5.646051406860352, Clf: 0.8785691261291504
J: 16.454254150390625, L: 2.20969820022583, U: 2.203416347503662, Clf: 0.9407141208648682
J: 13.802461624145508, L: -9.472404479980469, U: 12.043622016906738, Clf: 0.8774409294128418
J: 4.527091979980469, L: -21.11874771118164, U: 16.887664794921875, Clf: 0.6842324137687683
J: 25.184553146362305, L: 21.503582000732422, U: -8.526070594787598, Clf: 0.9536750912666321
J: -7.714491844177246, L: -5.908435344696045, U: -13.083610534667969, Clf: 0.881058

J: -182.5281219482422, L: -103.3153076171875, U: -87.80377197265625, Clf: 0.6711688041687012
J: -121.03032684326172, L: -62.10957336425781, U: -69.95521545410156, Clf: 0.862067461013794
J: -200.43798828125, L: -130.47694396972656, U: -78.63784790039062, Clf: 0.6778758764266968
J: -167.4507598876953, L: -104.78825378417969, U: -70.606689453125, Clf: 0.6206389665603638
J: -190.52044677734375, L: -118.4954833984375, U: -81.23854064941406, Clf: 0.7198104858398438
J: -160.8734588623047, L: -71.12211608886719, U: -97.76535034179688, Clf: 0.6260940432548523
J: -222.58456420898438, L: -123.96791076660156, U: -108.58290100097656, Clf: 0.7786133289337158
J: -177.65866088867188, L: -86.23480224609375, U: -101.72545623779297, Clf: 0.8048129081726074
J: -157.54981994628906, L: -70.93718719482422, U: -99.75155639648438, Clf: 1.0264791250228882
J: -171.90257263183594, L: -116.13007354736328, U: -61.68026351928711, Clf: 0.4615439772605896


In [41]:
L

tensor(244.8381, grad_fn=<NegBackward>)

In [42]:
CLF

tensor(1.3806, grad_fn=<NllLossBackward>)

In [43]:
J

tensor(469.2866, grad_fn=<AddBackward0>)

# TEST


In [147]:
model = M2(30,4)

In [148]:
gmt = Generative_Model_Trainer(model=model, optimizer=torch.optim.Adam, scheduler=None, lr=1e-3, tensorboard=False, verbose=True)

Using device cuda:0
Generative_Model_Trainer(
  (model): M2(
    (encoder): Sequential(
      (0): Linear(in_features=164, out_features=400, bias=True)
      (1): ELU(alpha=1.0)
      (2): Linear(in_features=400, out_features=200, bias=True)
      (3): ELU(alpha=1.0)
      (4): VariationalLayer(
        (mu): Linear(in_features=200, out_features=30, bias=True)
        (rho): Linear(in_features=200, out_features=30, bias=True)
        (softplus): Softplus(beta=1, threshold=20)
      )
    )
    (decoder): Sequential(
      (0): Linear(in_features=34, out_features=200, bias=True)
      (1): ELU(alpha=1.0)
      (2): Linear(in_features=200, out_features=400, bias=True)
      (3): ELU(alpha=1.0)
      (4): VariationalDecoderOutput(
        (mu): Linear(in_features=400, out_features=160, bias=True)
        (rho): Linear(in_features=400, out_features=1, bias=True)
        (softplus): Softplus(beta=1, threshold=20)
      )
    )
    (classify): Sequential(
      (0): Linear(in_features=160, o

In [108]:
gmt(epochs=range(20), supervised_dataset=sup, unsupervised_dataset = X_unsup, validation_dataset=test, batch_size=128)

Epoch [1/20], average_loss:-247.5030, validation_loss:-340.2869, val_accuracy:0.7785
Epoch [2/20], average_loss:-375.4871, validation_loss:-390.9652, val_accuracy:0.8009
Epoch [3/20], average_loss:-399.0818, validation_loss:-409.1270, val_accuracy:0.8010
Epoch [4/20], average_loss:-416.2138, validation_loss:-428.5248, val_accuracy:0.8179
Epoch [5/20], average_loss:-435.0976, validation_loss:-404.9739, val_accuracy:0.8142
Epoch [6/20], average_loss:-447.3335, validation_loss:-461.6093, val_accuracy:0.8072
Epoch [7/20], average_loss:-462.6327, validation_loss:-467.0352, val_accuracy:0.8170
Epoch [8/20], average_loss:-467.5099, validation_loss:-474.8092, val_accuracy:0.8212
Epoch [9/20], average_loss:-473.7803, validation_loss:-479.2372, val_accuracy:0.8282
Epoch [10/20], average_loss:-483.4562, validation_loss:-491.3083, val_accuracy:0.8117
Epoch [11/20], average_loss:-492.1682, validation_loss:-495.8349, val_accuracy:0.8322
Epoch [12/20], average_loss:-499.3201, validation_loss:-498.443

{'train_total_loss': -441190.11865234375,
 'train_classifier_loss': 243.34739600121975,
 'train_supervised_loss': -222824.86700439453,
 'train_unsupervised_loss': -221480.09805297852,
 'validation_total_loss': -68991.56201171875,
 'validation_classifier_loss': 47.51626372337341,
 'validation_supervised_loss': -34807.221252441406,
 'validation_unsupervised_loss': -34792.54878234863,
 'validation_accuracy': 113.72395833333333}

tensor(22026.4648)

In [119]:
temp = False

In [120]:
tmp = 1 \
    if temp==True else 0

In [121]:
tmp

0