In [None]:
import torch
import numpy as np
from alae import MappingD, MappingF, EncoderNoStyle, EncoderDefault, Discriminator

In [None]:
latent_size = 128
mapping_layers = 5
layer_count = 3
mapping_f = MappingF(
    num_layers= 2 * layer_count,
    latent_size=latent_size,
    dlatent_size=latent_size,
    mapping_fmaps=latent_size,
    mapping_layers=mapping_layers
)
mapping_d = MappingD(
    latent_size=latent_size,
    dlatent_size=latent_size,
    mapping_fmaps=latent_size,
    mapping_layers=3,
)

encoder = EncoderDefault(
    startf=32,
    maxf=128,
    latent_size=latent_size,
    layer_count=3,
)


In [None]:

encoder_2 = EncoderNoStyle(
    startf=4, maxf=128, layer_count=5, latent_size=latent_size
)

In [None]:
inputs = torch.randn(3, 4, 4)

In [None]:
encoder(inputs, 0, 1).shape


In [None]:
encoder(inputs, 2, 1).shape

In [None]:
encoder_2(inputs, 0, 1)

In [None]:
z = torch.randn(32, 128)

In [None]:
styles = mapping_f(z)
styles2 = styles[:,0]
styles2 = styles2.view(styles2.shape[0], 1, styles2.shape[1]).repeat(1, mapping_f.num_layers, 1)

In [None]:
styles - styles2

In [None]:
styles.shape

In [None]:
styles2.shape

In [None]:
layer_idx = torch.arange(mapping_f.num_layers)[
                    np.newaxis, :, np.newaxis]

In [None]:
layer_idx

In [None]:
a = torch.lerp(torch.Tensor([1.0]),  torch.Tensor([0.0]))

In [None]:
x = torch.randn(10,3,32,32)
m = torch.mean(x, dim=[2, 3], keepdim=True)
std = torch.sqrt(torch.mean((x - m) ** 2, dim=[2, 3], keepdim=True))
style_1 = torch.cat((m, std), dim=1)

In [None]:
style_1.shape

In [None]:
m.shape

In [2]:
from defaults import get_cfg_defaults
import os
import imageio
import random
import tqdm
import numpy as np
from utils import get_main_directory

cfg = get_cfg_defaults()

In [4]:
data_path = os.path.join(get_main_directory(), "data/")
dataset_name = "COVID-19_Radiography_Dataset"
data_list = ["COVID", "Normal"]
count = []
if not os.path.exists(data_path):
    print("------Downloading Dataset from Kaggle------")
    os.system(
        f"kaggle datasets download tawsifurrahman/covid19-radiography-database -p {data_path} --unzip")
if not os.path.exists(os.path.dirname(cfg.DATASET.PATH)):
    os.makedirs(os.path.dirname(cfg.DATASET.PATH))

for r in range(2, cfg.DATASET.MAX_RESOLUTION_LEVEL+1):
    img_size = 2*r
    for data_dir in data_list:
        root, dirs, files = next(os.walk(os.path.join(data_path, dataset_name, data_dir), topdown=False))
        count.append(len(files))
        random.shuffle(files)
        files = files[:min(count)]
        images = []
        for file in files:
            image = imageio.imread(root + "/" + file)
            image = np.resize(image, (img_size,img_size))
            # put images in dict
            images.append(image)

  image = imageio.imread(root + "/" + file)


[[162 162 162 162]
 [162 162 162 162]
 [162 162 162 162]
 [162 162 161 161]]
uint8
[[  6   7   8   8]
 [  7   7   8   8]
 [ 10   4   7  62]
 [154 167  67  10]]
uint8
[[ 81  82  84  86]
 [ 87  89  91  95]
 [ 97 101  98  96]
 [103 105 111 117]]
uint8
[[0 0 0 0]
 [0 0 0 0]
 [0 0 0 0]
 [0 0 0 0]]
uint8
[[187 196 197 194]
 [187 180 186 191]
 [167 138 118 103]
 [ 85  70  74  78]]
uint8
[[ 5  5  6  7]
 [ 7  7  8  8]
 [ 9 10 10 10]
 [10 11 12 13]]
uint8
[[125  31   5   8]
 [  9   7   5   4]
 [  4   4   4   4]
 [  4   3   3   2]]
uint8
[[0 0 0 0]
 [0 0 0 0]
 [0 0 0 0]
 [0 0 0 0]]
uint8
[[11  7  3  0]
 [ 0  1  1  0]
 [ 0  0  0  0]
 [ 0  0  0  0]]
uint8
[[85 62 51 52]
 [54 49 46 47]
 [44 41 41 39]
 [34 31 32 33]]
uint8
[[0 0 0 0]
 [0 0 0 0]
 [0 0 0 0]
 [0 0 0 0]]
uint8
[[12 15 14 14]
 [16 18 20 25]
 [27 28 30 33]
 [34 36 37 39]]
uint8
[[12 12 12 12]
 [12 12 12 12]
 [12 12 12 12]
 [12 12 12 12]]
uint8
[[187 155 105  74]
 [ 67  70  72  72]
 [ 74  81 107 134]
 [135 114  92  94]]
uint8
[[0 0 0 0]
 [0

KeyboardInterrupt: 

In [33]:
import torch
from torch.utils.data import DataLoader
from dataloader import CovidTfRecordDataset
from defaults import get_cfg_defaults
import logging
import numpy as np


logger = logging.getLogger()
cfg = get_cfg_defaults()

dataset = CovidTfRecordDataset(cfg, logger)
dataloader = DataLoader(dataset, batch_size=32)


In [41]:
dataset.reset(5)
img_size = 2**5

In [42]:
for data in dataloader:
    img_shape = data["shape"]

    image = data["data"]
    print(type(img_shape))
    print(np.frombuffer(image[0],dtype=np.uint8))
    tensor = torch.from_numpy(
                np.concatenate([
                    np.frombuffer(
                        i,
                        dtype=np.uint8).reshape(1, cfg.MODEL.CHANNELS, img_size, img_size)
                    for i in data['data']],
                    axis=0))

<class 'torch.Tensor'>
[  0   0  23 ... 201 192 189]
<class 'torch.Tensor'>
[ 2  0  0 ... 70  4  0]
<class 'torch.Tensor'>
[3 3 3 ... 3 2 2]
<class 'torch.Tensor'>
[30 26 27 ... 24  0  0]
<class 'torch.Tensor'>
[0 0 0 ... 0 0 0]
<class 'torch.Tensor'>
[ 0  0  0 ... 26  0  0]
<class 'torch.Tensor'>
[0 0 1 ... 0 1 1]
<class 'torch.Tensor'>
[23 23 28 ... 67  0  0]
<class 'torch.Tensor'>
[ 1  1  1 ... 23 27 29]
<class 'torch.Tensor'>
[ 0 38 37 ...  2  1  1]
<class 'torch.Tensor'>
[ 2  3  3 ... 79 13  0]
<class 'torch.Tensor'>
[  2   2   2 ... 124  76   2]
<class 'torch.Tensor'>
[178   2  12 ... 206 199 105]
<class 'torch.Tensor'>
[  1   1   1 ... 189 161  90]
<class 'torch.Tensor'>
[10  8 10 ...  0  8 16]
<class 'torch.Tensor'>
[0 0 0 ... 0 0 0]
<class 'torch.Tensor'>
[ 22  17  16 ... 192 133  52]
<class 'torch.Tensor'>
[  0   0   0 ... 100  12   0]
<class 'torch.Tensor'>
[ 96 124 117 ...   3   1   1]
<class 'torch.Tensor'>
[17 18 30 ...  3  0  0]
<class 'torch.Tensor'>
[ 8  0  0 ... 94 68

In [1]:
import torch
import os
from PIL import Image
from torchvision.utils import save_image
from model import Model
from custom_adam import LREQAdam
from scheduler import ComboMultiStepLR
from tracker import LossTracker
from checkpointer import Checkpointer
import lod_driver
from tqdm import tqdm
import numpy as np
from dataloader import *
import torch.nn.functional as F
from defaults import get_cfg_defaults
from launcher import run
import utils
import logging
from torch.utils.data import DataLoader

cfg = get_cfg_defaults()
logger = logging.getLogger()


def save_sample(lod2batch, tracker, sample, samplez, x, logger, model, cmodel, cfg, encoder_optimizer, decoder_optimizer):
    os.makedirs('results', exist_ok=True)
# Save Sample LOL
    logger.info('\n[%d/%d] - ptime: %.2f, %s, blend: %.3f, lr: %.12f,  %.12f, max mem: %f",' % (
        (lod2batch.current_epoch +
         1), cfg.TRAIN.TRAIN_EPOCHS, lod2batch.per_epoch_ptime, str(tracker),
        lod2batch.get_blend_factor(),
        encoder_optimizer.param_groups[0]['lr'], decoder_optimizer.param_groups[0]['lr'],
        torch.cuda.max_memory_allocated() / 1024.0 / 1024.0))

    with torch.no_grad():
        model.eval()
        cmodel.eval()
        sample = sample[:lod2batch.get_per_GPU_batch_size()]
        samplez = samplez[:lod2batch.get_per_GPU_batch_size()]

        needed_resolution = model.decoder.layer_to_resolution[lod2batch.lod]
        sample_in = sample
        while sample_in.shape[2] > needed_resolution:
            sample_in = F.avg_pool2d(sample_in, 2, 2)
        assert sample_in.shape[2] == needed_resolution

        blend_factor = lod2batch.get_blend_factor()
        if lod2batch.in_transition:
            needed_resolution_prev = model.decoder.layer_to_resolution[lod2batch.lod - 1]
            sample_in_prev = F.avg_pool2d(sample_in, 2, 2)
            sample_in_prev_2x = F.interpolate(
                sample_in_prev, needed_resolution)
            sample_in = sample_in * blend_factor + \
                sample_in_prev_2x * (1.0 - blend_factor)

        Z, _ = model.encode(sample_in, lod2batch.lod, blend_factor)

        if cfg.MODEL.Z_REGRESSION:
            Z = model.mapping_f(Z[:, 0])
        else:
            Z = Z.repeat(1, model.mapping_f.num_layers, 1)

        rec1 = model.decoder(Z, lod2batch.lod, blend_factor, noise=True)
        rec2 = cmodel.decoder(Z, lod2batch.lod, blend_factor, noise=True)

        Z = model.mapping_f(samplez)
        g_rec = model.decoder(Z, lod2batch.lod, blend_factor, noise=True)

        Z = cmodel.mapping_f(samplez)
        cg_rec = cmodel.decoder(Z, lod2batch.lod, blend_factor, noise=True)

        resultsample = torch.cat([sample_in, rec1, rec2, g_rec, cg_rec], dim=0)

        @utils.async_func
        def save_pic(x_rec):
            tracker.register_means(
                lod2batch.current_epoch + lod2batch.iteration * 1.0 / lod2batch.get_dataset_size())
            tracker.plot()

            result_sample = x_rec * 0.5 + 0.5
            result_sample = result_sample.cpu()
            f = os.path.join(cfg.OUTPUT_DIR,
                             'sample_%d_%d.jpg' % (
                                 lod2batch.current_epoch + 1,
                                 lod2batch.iteration // 1000))
            print("Saved to %s" % f)
            save_image(result_sample, f, nrow=min(
                32, lod2batch.get_per_GPU_batch_size()))
        save_pic(resultsample)


def train(cfg, logger):
    model = Model(
        startf=cfg.MODEL.START_CHANNEL_COUNT,
        maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
        layer_count=cfg.MODEL.LAYER_COUNT,
        latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
        dlatent_avg_beta=cfg.MODEL.DLATENT_AVG_BETA,
        style_mixing_prob=cfg.MODEL.STYLE_MIXING_PROB,
        mapping_layers=cfg.MODEL.MAPPING_LAYERS,
        channels=cfg.MODEL.CHANNELS,
        generator=cfg.MODEL.GENERATOR,
        encoder=cfg.MODEL.ENCODER,
        z_regression=cfg.MODEL.Z_REGRESSION
    )

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    decoder = model.decoder
    encoder = model.encoder
    mapping_d = model.mapping_d
    mapping_f = model.mapping_f
    dlatent_avg = model.dlatent_avg

    arguments = dict()
    arguments["iteration"] = 0

    decoder_optim = LREQAdam(
        [
            {'params': decoder.parameters()},
            {'params': mapping_f.parameters()}
        ],
        lr=cfg.TRAIN.BASE_LEARNING_RATE,
        betas=(cfg.TRAIN.ADAM_BETA_0, cfg.TRAIN.ADAM_BETA_1),
        weight_decay=0
    )
    encoder_optim = LREQAdam(
        [
            {'params': encoder.parameters()},
            {'params': mapping_d.parameters()}
        ],
        lr=cfg.TRAIN.BASE_LEARNING_RATE,
        betas=(cfg.TRAIN.ADAM_BETA_0, cfg.TRAIN.ADAM_BETA_1),
        weight_decay=0
    )

    scheduler = ComboMultiStepLR(
        optimizers={
            'encoder_optimizer': encoder_optim,
            'decoder_optimizer': decoder_optim
        },
        milestones=cfg.TRAIN.LEARNING_DECAY_STEPS,
        gamma=cfg.TRAIN.LEARNING_DECAY_RATE,
        reference_batch_size=32,
        base_lr=cfg.TRAIN.LEARNING_RATES
    )

    model_dict = {
        'discriminator': encoder,
        'generator': decoder,
        'mapping_tl': mapping_d,
        'mapping_fl': mapping_f,
        'dlatent_avg': dlatent_avg
    }

    tracker = LossTracker(
        cfg.OUTPUT_DIR
    )

    checkpointer = Checkpointer(cfg,
                                model_dict,
                                {
                                    'encoder_optimizer': encoder_optim,
                                    'decoder_optimizer': decoder_optim,
                                    'scheduler': scheduler,
                                    'tracker': tracker
                                },
                                logger=logger,
                                save=True)

    extra_checkpoint_data = checkpointer.load()
    logger.info(f'Starting from epoch: {scheduler.start_epoch()}')

    arguments.update(extra_checkpoint_data)

    layer_to_resolution = decoder.layer_to_resolution

    dataset = CovidTfRecordDataset(cfg, logger)

    rnd = np.random.RandomState(0)
    latents = rnd.randn(32, cfg.MODEL.LATENT_SPACE_SIZE)
    samplez = torch.Tensor(latents).float().cuda()

    lod2batch = lod_driver.LODDriver(
        cfg,
        logger,
        world_size=1,
        dataset_size=len(dataset)
    )

    if cfg.DATASET.SAMPLES_PATH != 'no_path':
        path = cfg.DATASET.SAMPLES_PATH
        src = []
        with torch.no_grad():
            for filename in list(os.listdir(path))[:32]:
                img = np.asarray(Image.open(os.path.join(path, filename)))
                if img.shape[2] == 4:
                    img = img[:, :, :3]
                im = img.transpose((2, 0, 1))
                x = torch.Tensor(np.asarray(
                    im, dtype=np.float32), requires_grad=True).cuda() / 127.5 - 1.
                if x.shape[0] == 4:
                    x = x[:3]
                src.append(x)
                sample = torch.stack(src)
    else:
        dataset.reset(cfg.DATASET.MAX_RESOLUTION_LEVEL)
        data_batch = next(iter(DataLoader(
            dataset=dataset,
            batch_size=lod2batch.get_per_GPU_batch_size()
        )))
        img_size = 2**cfg.DATASET.MAX_RESOLUTION_LEVEL
        sample = data_batch['data']
        sample = torch.cat([torch.frombuffer(i, dtype=torch.uint8).reshape(
            1, 1, img_size, img_size) for i in sample], dim=0)
        print(sample.shape)
        sample = (sample / 127.5 - 1.)
    lod2batch.set_epoch(scheduler.start_epoch(), [
                        encoder_optim, decoder_optim])

    for epoch in range(scheduler.start_epoch(), cfg.TRAIN.TRAIN_EPOCHS):
        model.train()
        lod2batch.set_epoch(epoch, [encoder_optim, decoder_optim])

        logger.info("Batch size: %d, Batch size per GPU: %d, LOD: %d - %dx%d, blend: %.3f, dataset size: %d" % (
            lod2batch.get_batch_size(),
            lod2batch.get_per_GPU_batch_size(),
            lod2batch.lod,
            2 ** lod2batch.get_lod_power2(),
            2 ** lod2batch.get_lod_power2(),
            lod2batch.get_blend_factor(),
            len(dataset)))
        img_size = 2 ** lod2batch.get_lod_power2()
        dataset.reset(lod2batch.get_lod_power2())
        dataloader = DataLoader(
            dataset=dataset,
            batch_size=lod2batch.get_per_GPU_batch_size()
        )

        scheduler.set_batch_size(lod2batch.get_batch_size(), lod2batch.lod)

        model.train()

        need_permute = False
        epoch_start_time = time.time()

        i = 0
        for data in tqdm(dataloader):

            x_orig = torch.from_numpy(
                np.concatenate([
                    np.frombuffer(
                        i,
                        dtype=np.uint8).reshape(1, cfg.MODEL.CHANNELS, img_size, img_size)
                    for i in data['data']],
                    axis=0))
            i += 1
            with torch.no_grad():
                if x_orig.shape[0] != lod2batch.get_per_GPU_batch_size():
                    continue
                if need_permute:
                    x_orig = x_orig.permute(0, 3, 1, 2)
                x_orig = (x_orig / 127.5 - 1.)

                blend_factor = lod2batch.get_blend_factor()

                needed_resolution = layer_to_resolution[lod2batch.lod]
                x = x_orig

                if lod2batch.in_transition:
                    needed_resolution_prev = layer_to_resolution[lod2batch.lod - 1]
                    x_prev = F.avg_pool2d(x_orig, 2, 2)
                    x_prev_2x = F.interpolate(x_prev, needed_resolution)
                    x = x * blend_factor + x_prev_2x * (1.0 - blend_factor)

            x.requires_grad = True

            encoder_optim.zero_grad()
            loss_d = model(x, lod2batch.lod, blend_factor,
                           d_train=True, ae=False)
            tracker.update(dict(loss_d=loss_d))
            loss_d.backward()
            encoder_optim.step()

            decoder_optim.zero_grad()
            loss_g = model(x, lod2batch.lod, blend_factor,
                           d_train=False, ae=False)
            tracker.update(dict(loss_g=loss_g))
            loss_g.backward()
            decoder_optim.step()

            encoder_optim.zero_grad()
            decoder_optim.zero_grad()
            # this part is buggy, if ae=True and d_train=True, only ae logic will run.
            lae = model(x, lod2batch.lod, blend_factor, d_train=True, ae=True)
            tracker.update(dict(lae=lae))
            lae.backward()
            encoder_optim.step()
            decoder_optim.step()

            epoch_end_time = time.time()
            per_epoch_ptime = epoch_end_time - epoch_start_time

            lod_for_saving_model = lod2batch.lod
            lod2batch.step()
            if lod2batch.is_time_to_save():
                checkpointer.save(
                    "model_tmp_intermediate_lod%d" % lod_for_saving_model)
            if lod2batch.is_time_to_report():
                save_sample(lod2batch, tracker, sample, samplez, x, logger, model,
                            model.module if hasattr(
                                model, "module") else model, cfg, encoder_optim,
                            decoder_optim)

        scheduler.step()

        # if local_rank == 0:
        checkpointer.save("model_tmp_lod%d" % lod_for_saving_model)
        save_sample(lod2batch, tracker, sample, samplez, x, logger, model,
                    model.module if hasattr(model, "module") else model, cfg, encoder_optim, decoder_optim)

    logger.info("Training finish!... save training results")
    # if local_rank == 0:
    checkpointer.save("model_final").wait()


if __name__ == "__main__":
    gpu_count = torch.cuda.device_count()
    run(train, get_cfg_defaults(), description='StyleGAN', default_config='configs/covid.yaml',
        world_size=gpu_count)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
train(cfg, logger)

  sample = torch.cat([torch.frombuffer(i, dtype=torch.uint8).reshape(


Epoch 00000: adjusting learning rate of group 0 to 2.0000e-03.
Epoch 00000: adjusting learning rate of group 1 to 2.0000e-03.
Epoch 00000: adjusting learning rate of group 0 to 2.0000e-03.
Epoch 00000: adjusting learning rate of group 1 to 2.0000e-03.
torch.Size([16, 1, 256, 256])


	addcmul_(Number value, Tensor tensor1, Tensor tensor2)
Consider using one of the following signatures instead:
	addcmul_(Tensor tensor1, Tensor tensor2, *, Number value) (Triggered internally at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\torch\csrc\utils\python_arg_parser.cpp:1420.)
  exp_avg_sq.mul_(beta_2).addcmul_(1 - beta_2, grad, grad)


torch.Size([128, 1, 256])
torch.Size([128, 1, 256])


  4%|▍         | 1/25 [00:02<00:52,  2.20s/it]

torch.Size([128, 1, 256])
torch.Size([128, 1, 256])


  8%|▊         | 2/25 [00:04<00:51,  2.23s/it]

torch.Size([128, 1, 256])
torch.Size([128, 1, 256])


 12%|█▏        | 3/25 [00:06<00:46,  2.13s/it]

torch.Size([128, 1, 256])
torch.Size([128, 1, 256])


 16%|█▌        | 4/25 [00:08<00:44,  2.11s/it]

torch.Size([128, 1, 256])
torch.Size([128, 1, 256])


 20%|██        | 5/25 [00:10<00:42,  2.11s/it]

torch.Size([128, 1, 256])
torch.Size([128, 1, 256])


 24%|██▍       | 6/25 [00:12<00:39,  2.10s/it]

torch.Size([128, 1, 256])
torch.Size([128, 1, 256])


 28%|██▊       | 7/25 [00:14<00:37,  2.10s/it]

torch.Size([128, 1, 256])
torch.Size([128, 1, 256])


 32%|███▏      | 8/25 [00:17<00:36,  2.12s/it]

torch.Size([128, 1, 256])
torch.Size([128, 1, 256])


 36%|███▌      | 9/25 [00:19<00:33,  2.12s/it]

torch.Size([128, 1, 256])
torch.Size([128, 1, 256])


 40%|████      | 10/25 [00:21<00:33,  2.21s/it]

torch.Size([128, 1, 256])
torch.Size([128, 1, 256])


 44%|████▍     | 11/25 [00:23<00:30,  2.18s/it]

torch.Size([128, 1, 256])
torch.Size([128, 1, 256])


 48%|████▊     | 12/25 [00:26<00:29,  2.26s/it]

torch.Size([128, 1, 256])
torch.Size([128, 1, 256])


 52%|█████▏    | 13/25 [00:28<00:27,  2.27s/it]

torch.Size([128, 1, 256])
torch.Size([128, 1, 256])


 56%|█████▌    | 14/25 [00:30<00:24,  2.25s/it]

torch.Size([128, 1, 256])
torch.Size([128, 1, 256])


 60%|██████    | 15/25 [00:32<00:22,  2.23s/it]

torch.Size([128, 1, 256])
torch.Size([128, 1, 256])


 64%|██████▍   | 16/25 [00:35<00:20,  2.25s/it]

torch.Size([128, 1, 256])
torch.Size([128, 1, 256])


 68%|██████▊   | 17/25 [00:37<00:17,  2.25s/it]

torch.Size([128, 1, 256])
torch.Size([128, 1, 256])


 72%|███████▏  | 18/25 [00:39<00:15,  2.26s/it]

torch.Size([128, 1, 256])
torch.Size([128, 1, 256])


 76%|███████▌  | 19/25 [00:41<00:13,  2.24s/it]

torch.Size([128, 1, 256])
torch.Size([128, 1, 256])


 80%|████████  | 20/25 [00:44<00:11,  2.25s/it]

torch.Size([128, 1, 256])
torch.Size([128, 1, 256])


 84%|████████▍ | 21/25 [00:46<00:09,  2.25s/it]

torch.Size([128, 1, 256])
torch.Size([128, 1, 256])


 88%|████████▊ | 22/25 [00:48<00:07,  2.34s/it]

torch.Size([128, 1, 256])
torch.Size([128, 1, 256])


 92%|█████████▏| 23/25 [00:51<00:04,  2.33s/it]

torch.Size([128, 1, 256])
torch.Size([128, 1, 256])


100%|██████████| 25/25 [00:53<00:00,  2.13s/it]


Adjusting learning rate of group 0 to 2.0000e-03.
Adjusting learning rate of group 1 to 2.0000e-03.
Adjusting learning rate of group 0 to 2.0000e-03.
Adjusting learning rate of group 1 to 2.0000e-03.


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_addmm)