In [1]:
import torch

In [5]:
ckpt = torch.load('/home/miyashita21/gitrepo/3dshapes/checkpoints/autoencoder/best-epoch=1549-val_loss=0.0052.ckpt')

In [6]:
ckpt['optimizer_states'][0]['param_groups'][0]['lr'] = 5e-4

In [7]:
torch.save(ckpt, '/home/miyashita21/gitrepo/3dshapes/checkpoints/autoencoder/best-epoch=1549-val_loss=0.0052_modified.ckpt')

In [None]:
import pytorch_lightning as pl
import torchvision as tv
import wandb

class Model(torch.nn.Module):
    def __init__(self, hidden_channels=256, latent_dim=128):
        super().__init__()
        self.encoder = torch.nn.Sequential(
            torch.nn.Conv2d(3, 32, kernel_size=3, padding=1),
            torch.nn.Softplus(),
            torch.nn.LayerNorm([32, 64, 64]),
            torch.nn.Conv2d(32, 32, kernel_size=3, padding=1),
            torch.nn.Softplus(),
            torch.nn.LayerNorm([32, 64, 64]),
            torch.nn.Conv2d(32, 64, kernel_size=4, stride=4),
            torch.nn.Softplus(),
            torch.nn.LayerNorm([64, 16, 16]),
            torch.nn.Conv2d(64, 64, kernel_size=3, padding=1),
            torch.nn.Softplus(),
            torch.nn.LayerNorm([64, 16, 16]),
            torch.nn.Conv2d(64, 64, kernel_size=3, padding=1),
            torch.nn.Softplus(),
            torch.nn.LayerNorm([64, 16, 16]),
            torch.nn.Conv2d(64, 128, kernel_size=4, stride=4),
            torch.nn.Softplus(),
            torch.nn.LayerNorm([128, 4, 4]),
            torch.nn.Conv2d(128, 128, kernel_size=3, padding=1),
            torch.nn.Softplus(),
            torch.nn.LayerNorm([128, 4, 4]),
            torch.nn.Conv2d(128, 128, kernel_size=3, padding=1),
            torch.nn.Softplus(),
            torch.nn.LayerNorm([128, 4, 4]),
            torch.nn.Conv2d(128, 256, kernel_size=4, stride=4),
            torch.nn.Softplus(),
            torch.nn.LayerNorm([256, 1, 1]),
            torch.nn.Conv2d(256, latent_dim, kernel_size=1),
        )

        self.decoder = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(latent_dim, 256, kernel_size=1),
            torch.nn.Softplus(),
            torch.nn.LayerNorm([256, 1, 1]),
            torch.nn.ConvTranspose2d(256, 128, kernel_size=4, stride=4),
            torch.nn.Softplus(),
            torch.nn.LayerNorm([128, 4, 4]),
            torch.nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1),
            torch.nn.Softplus(),
            torch.nn.LayerNorm([128, 4, 4]),
            torch.nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1),
            torch.nn.Softplus(),
            torch.nn.LayerNorm([128, 4, 4]),
            torch.nn.ConvTranspose2d(128, 64, kernel_size=4, stride=4),
            torch.nn.Softplus(),
            torch.nn.LayerNorm([64, 16, 16]),
            torch.nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1),
            torch.nn.Softplus(),
            torch.nn.LayerNorm([64, 16, 16]),
            torch.nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1),
            torch.nn.Softplus(),
            torch.nn.LayerNorm([64, 16, 16]),
            torch.nn.ConvTranspose2d(64, 32, kernel_size=4, stride=4),
            torch.nn.Softplus(),
            torch.nn.LayerNorm([32, 64, 64]),
            torch.nn.ConvTranspose2d(32, 32, kernel_size=3, padding=1),
            torch.nn.Softplus(),
            torch.nn.LayerNorm([32, 64, 64]),
            torch.nn.ConvTranspose2d(32, 3, kernel_size=3, padding=1),
            torch.nn.Sigmoid()
        )

    def forward(self, x):
        z = self.encoder(x)
        x_recon = self.decoder(z)
        return x_recon


class Autoencoder(pl.LightningModule):
    def __init__(self, model, sample_num=64):
        super().__init__()
        self.model = model
        self.recon_imgs = None
            
    def on_validation_epoch_end(self):
        # 検証終了時に画像生成しwandbに記録
        if self.recon_imgs is None:
            return
        with torch.no_grad():
            img = self.recon_imgs
            grid = tv.utils.make_grid(img.cpu(), nrow=8)
            wandb_logger = self.logger
            if hasattr(wandb_logger, "experiment"):
                wandb_logger.experiment.log({f"val_generated/epoch_{self.current_epoch}": wandb.Image(grid, caption=f"epoch {self.current_epoch}")})
            self.recon_imgs = None

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        recon = self.model(batch)
        loss = torch.nn.functional.mse_loss(recon, batch)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        recon = self.model(batch)
        if self.recon_imgs is None:
            self.recon_imgs = recon.clamp(0, 1).cpu()
        loss = torch.nn.functional.mse_loss(recon, batch)
        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adamax(self.model.parameters(), lr=1e-3, weight_decay=1e-5)
        return optimizer

model = Model(hidden_channels=64, latent_dim=256)
autoencoder = Autoencoder(model).load_state_dict(ckpt['state_dict'])



In [3]:
import h5py

# load dataset
dataset = h5py.File('3dshapes.h5', 'r')
print(dataset.keys())
images = dataset['images']
labels = dataset['labels']

<KeysViewHDF5 ['images', 'labels']>


In [2]:
labels[0]

array([  0.  ,   0.  ,   0.  ,   0.75,   0.  , -30.  ])

In [1]:
import torch
n = 192

x = torch.randn(n*n, n*n).cuda()

x = x + x.T

l, u = torch.linalg.eigh(x)

_LinAlgError: cusolver error: CUSOLVER_STATUS_INVALID_VALUE, when calling `cusolverDnXsyevd_bufferSize( handle, params, jobz, uplo, n, CUDA_R_32F, reinterpret_cast<const void*>(A), lda, CUDA_R_32F, reinterpret_cast<const void*>(W), CUDA_R_32F, workspaceInBytesOnDevice, workspaceInBytesOnHost)`. This error may appear if the input matrix contains NaN. If you keep seeing this error, you may use `torch.backends.cuda.preferred_linalg_library()` to try linear algebra operators with other supported backends. See https://pytorch.org/docs/stable/backends.html#torch.backends.cuda.preferred_linalg_library