In [5]:
import torch.backends.cudnn as cudnn
import torch
import torchvision.transforms as transforms
import PIL
import argparse
import os
import random
import sys
import pprint
import datetime
import dateutil.tz
import numpy as np
import json
import functools

In [33]:
!pip3 install -r requirements.txt

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting easydict
  Downloading easydict-1.9.tar.gz (6.4 kB)
Building wheels for collected packages: easydict
  Building wheel for easydict (setup.py) ... [?25ldone
[?25h  Created wheel for easydict: filename=easydict-1.9-py3-none-any.whl size=6361 sha256=43f1c23d07465a0f628126b913bc8511786b87697172915f02c2f6b9228b2f85
  Stored in directory: /tmp/pip-ephem-wheel-cache-croolvsj/wheels/d3/e0/e9/305e348717e399665119bd012510d51ff4f22d709ff60c3096
Successfully built easydict
Installing collected packages: easydict
Successfully installed easydict-1.9


In [34]:
from storygen.config import cfg, cfg_from_file

In [4]:
# Seed every RNG the pipeline draws from: Python `random` (randrange frame sampling),
# NumPy (np.random in VideoFolderDataset.sample_image) and torch (CPU + all GPUs).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if cfg.CUDA:
    print('CUDA Flag enabled: ', cfg.CUDA)
    torch.cuda.manual_seed_all(SEED)
now = datetime.datetime.now(dateutil.tz.tzlocal())
timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
output_dir = './output/%s_%s' % (cfg.DATASET_NAME, cfg.CONFIG_NAME)

# number of gpus: cfg.GPU_ID is a comma-separated list of device ids
num_gpu = len(cfg.GPU_ID.split(','))
print("number of GPUs: ", num_gpu)

CUDA Flag enabled:  True
number of GPUs:  1


# Transform

Basic data manipulation: a per-frame image transform and a clip-level video transform built from it.

In [35]:
def video_transform(video, image_transform):
    """Apply ``image_transform`` to every frame of ``video`` and stack them channel-first.

    Args:
        video: iterable of frames (StoryDataset passes a ``T x H x W x C`` array).
        image_transform: per-frame callable returning a ``C x H x W`` tensor.

    Returns:
        Tensor of shape ``C x T x H x W``.
    """
    frames = [image_transform(im) for im in video]
    # torch.stack gives T x C x H x W; permute swaps to C x T x H x W.
    return torch.stack(frames).permute(1, 0, 2, 3)


if cfg.TRAIN.FLAG:
    print('TRAIN FLAG ENABLED:', cfg.TRAIN.FLAG)
    # Frames arrive as numpy arrays; resize to IMSIZE x IMSIZE and map to [-1, 1].
    image_transforms = transforms.Compose([PIL.Image.fromarray,
                                           transforms.Resize((cfg.IMSIZE, cfg.IMSIZE)),
                                           #transforms.RandomHorizontalFlip(),
                                           transforms.ToTensor(),
                                           transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    video_len = 5
    n_channels = 3
    # functools.partial binds the per-frame transform so datasets can call video_transforms(video)
    video_transforms = functools.partial(video_transform, image_transform=image_transforms)

TRAIN FLAG ENABLED: True


## GAN and data loading

In [6]:
import storygen.pororo_data as data
from storygen.train import gan_trainer

In [7]:
# Pororo-SV: frame counter plus train-split story- and image-level datasets.
dir_path = "./pororo_data/"
counter_file = os.path.join(dir_path, 'frames_counter.npy')
# NOTE: allow_pickle unpickles objects from the file — only load trusted data.
counter = np.load(counter_file, allow_pickle=True).item()
print("The number of frames: ", len(counter))
base = data.VideoFolderDataset(dir_path, counter=counter, cache=dir_path, min_len=4, mode="train")
storydataset = data.StoryDataset(base, dir_path, video_transforms)
imagedataset = data.ImageDataset(base, dir_path, image_transforms)

The number of frames:  183
Total number of clips 10191


In [8]:
# Per-GPU batch sizes from cfg are scaled by the number of GPUs listed in cfg.GPU_ID.
num_gpu = len(cfg.GPU_ID.split(','))
n_workers = int(cfg.WORKERS)

## dataloader
imageloader = torch.utils.data.DataLoader(
    imagedataset,
    batch_size=cfg.TRAIN.IM_BATCH_SIZE * num_gpu,
    drop_last=True,
    shuffle=True,
    num_workers=n_workers,
)
print("imageloader length: ", len(imageloader))
storyloader = torch.utils.data.DataLoader(
    storydataset,
    batch_size=cfg.TRAIN.ST_BATCH_SIZE * num_gpu,
    drop_last=True,
    shuffle=True,
    num_workers=n_workers,
)
print("storyloader length: ", len(storyloader))

## Validation (same data directory, "val" split, fixed order)
val_dir_path = dir_path
base_val = data.VideoFolderDataset(val_dir_path, counter, val_dir_path, 4, mode="val")
valdataset = data.StoryDataset(base_val, val_dir_path, video_transforms)
valloader = torch.utils.data.DataLoader(
    valdataset,
    batch_size=20,
    drop_last=True,
    shuffle=False,
    num_workers=n_workers,
)
print("Validation loader length: ", len(valloader))

imageloader length:  159
storyloader length:  159
Total number of clips 2320
Validation loader length:  116


## Train GAN

In [9]:
# Train the GAN. Note this rebinds `output_dir` (set earlier from cfg) to ./model,
# presumably where the trainer writes its outputs — TODO confirm in gan_trainer.
# NOTE(review): the recorded ImportError reports `model` from an "unknown location",
# i.e. a namespace package (bare directory). A ./model directory in the working
# directory could shadow the project's `model` module — verify before relying on this path.
output_dir = './model'
trainer_ratio = 1.0
algo = gan_trainer(cfg, output_dir, ratio=trainer_ratio)
algo.train(imageloader, storyloader, valloader, cfg.STAGE)

ImportError: cannot import name 'StoryGAN' from 'model' (unknown location)

In [35]:
# Persistent per-loader iterators so successive calls walk through an epoch instead
# of restarting it every time. Keyed by id(); the loader is stored alongside its
# iterator so the id cannot be recycled while the entry exists.
_real_image_iterators = {}


def sample_real_image_batch(imageloader, use_cuda=None):
    """Return the next batch from ``imageloader``, cycling through epochs.

    One iterator is kept per loader across calls. After the last batch of an
    epoch (index ``len(imageloader) - 1``) the iterator is dropped, so the next
    call starts a new pass (reshuffled if the loader shuffles). When CUDA is
    enabled, every value except the ``'text'`` entry is moved with ``.cuda()``.

    Args:
        imageloader: sized iterable yielding dict batches (e.g. a DataLoader).
        use_cuda: overrides ``cfg.CUDA``; ``None`` (default) reads ``cfg.CUDA``.

    Returns:
        A new dict with the same keys as the loader's batch.

    Raises:
        StopIteration: if the loader yields no batches at all.
    """
    if use_cuda is None:
        use_cuda = cfg.CUDA

    key = id(imageloader)
    entry = _real_image_iterators.get(key)
    if entry is None or entry[0] is not imageloader:
        entry = (imageloader, enumerate(imageloader))
        _real_image_iterators[key] = entry
    try:
        batch_idx, batch = next(entry[1])
    except StopIteration:
        # Exhausted without reaching index len()-1 (e.g. len() disagrees with iteration); restart.
        entry = (imageloader, enumerate(imageloader))
        _real_image_iterators[key] = entry
        batch_idx, batch = next(entry[1])

    if batch_idx == len(imageloader) - 1:
        # End of epoch: the next call creates a fresh iterator.
        del _real_image_iterators[key]

    if not use_cuda:
        return dict(batch)
    # 'text' holds raw strings, which cannot be moved to the GPU.
    return {k: (v if k == 'text' else v.cuda()) for k, v in batch.items()}

0 {'images': tensor([[[[ 0.1608,  0.1451,  0.1294,  ...,  0.3255,  0.5294,  0.4588],
          [ 0.1451,  0.1529,  0.2157,  ...,  0.3804,  0.5059,  0.4353],
          [ 0.1294,  0.1765,  0.0980,  ...,  0.4824,  0.5216,  0.4196],
          ...,
          [ 0.8667,  0.7333,  0.6078,  ...,  0.6627,  0.6549,  0.6392],
          [ 0.9059,  0.8980,  0.8824,  ...,  0.6471,  0.6471,  0.5765],
          [ 0.9059,  0.9059,  0.9059,  ...,  0.6549,  0.6235,  0.3804]],

         [[ 0.0902,  0.0824,  0.0824,  ...,  0.2392,  0.4118,  0.3333],
          [ 0.0745,  0.0745,  0.1294,  ...,  0.3020,  0.4039,  0.3490],
          [ 0.0667,  0.1373,  0.2471,  ...,  0.3882,  0.4196,  0.3490],
          ...,
          [ 0.7255,  0.5843,  0.4510,  ...,  0.5373,  0.5373,  0.5137],
          [ 0.7569,  0.7490,  0.7333,  ...,  0.4980,  0.5137,  0.4431],
          [ 0.7569,  0.7569,  0.7569,  ...,  0.5216,  0.4902,  0.2471]],

         [[-0.0980, -0.0980, -0.1059,  ..., -0.1922,  0.0353, -0.0588],
          [-0.098

In [28]:
import os, pickle
from tqdm import tqdm
import numpy as np
import torch.utils.data
import PIL
from random import randrange
from collections import Counter
import nltk
import json
import torchvision.transforms as transforms

unique_characters = ["Wilma", "Fred", "Betty", "Barney", "Dino", "Pebbles", "Mr Slate"]


class VideoFolderDataset(torch.utils.data.Dataset):
    """Index of FlintstonesSV clips that have ``min_len`` following clips in the same episode.

    Indexing returns a list of globalIDs: the clip itself followed by its
    ``min_len`` successors. The object also loads per-clip character labels
    (one 0/1 entry per name in ``unique_characters``) and the sentence-embedding
    table (``embeds`` + ``sent2idx``) used by the Story/Image datasets.

    Args:
        folder: data directory containing ``train-val-test_split.json``,
            ``flintstones_annotations_v1-0.json`` (or a cached ``labels.pkl``),
            ``flintstones_use_embeddings.npy`` and ``flintstones_use_embed_idxs.pkl``.
        cache: directory for ``following_cache<min_len>.pkl``; defaults to ``folder``.
        min_len: number of following clips a story needs.
        mode: ``'train'``, ``'val'`` or ``'test'``.

    Raises:
        ValueError: if ``mode`` is not one of the above.
    """

    def __init__(self, folder, cache=None, min_len=4, mode='train'):
        self.lengths = []
        self.followings = {}
        self.dir_path = folder
        self.total_frames = 0

        with open(os.path.join(self.dir_path, 'train-val-test_split.json'), 'r') as f:
            splits = json.load(f)
        train_id, val_id, test_id = splits["train"], splits["val"], splits["test"]

        # Probe, read and write the same .pkl path (the probe used to look for .npy, so it never hit).
        cache_dir = folder if cache is None else cache
        cache_file = os.path.join(cache_dir, 'following_cache' + str(min_len) + '.pkl')
        if os.path.exists(cache_file):
            # NOTE: pickle can execute arbitrary code; only load trusted cache files.
            with open(cache_file, 'rb') as f:
                self.followings = pickle.load(f)
        else:
            self.followings = self._compute_followings(train_id + val_id + test_id, min_len)
            with open(cache_file, 'wb') as f:
                pickle.dump(self.followings, f)

        self.labels = self._load_labels(folder)

        self.embeds = np.load(os.path.join(self.dir_path, "flintstones_use_embeddings.npy"))
        with open(os.path.join(self.dir_path, 'flintstones_use_embed_idxs.pkl'), 'rb') as f:
            self.sent2idx = pickle.load(f)

        # Clips near the end of the sorted list get shorter follower slices; keep only full stories.
        self.filtered_followings = {clip: f for clip, f in self.followings.items() if len(f) == min_len}
        self.followings = self.filtered_followings

        train_id = [tid for tid in train_id if tid in self.followings]
        val_id = [vid for vid in val_id if vid in self.followings]
        test_id = [tid for tid in test_id if tid in self.followings]

        if mode == 'train':
            self.orders = train_id
        elif mode == 'val':
            self.orders = val_id
        elif mode == 'test':
            self.orders = test_id
        else:
            raise ValueError
        print("Total number of clips {}".format(len(self.orders)))

    @staticmethod
    def _compute_followings(clips, min_len):
        """Map each clip to its next ``min_len`` clips (sorted order) when all share its season and episode.

        Season/episode are read from fields 1 and 3 of the '_'-separated globalID.
        """
        all_clips = sorted(clips)
        followings = {}
        for idx, clip in enumerate(all_clips):
            season, episode = int(clip.split('_')[1]), int(clip.split('_')[3])
            nxt = all_clips[idx + 1:idx + min_len + 1]
            if all(int(c.split('_')[1]) == season and int(c.split('_')[3]) == episode for c in nxt):
                followings[clip] = nxt
        return followings

    @staticmethod
    def _load_labels(folder):
        """Load ``labels.pkl`` or build it from the annotations: globalID -> 0/1 per unique character."""
        labels_file = os.path.join(folder, 'labels.pkl')
        if os.path.exists(labels_file):
            with open(labels_file, 'rb') as f:
                return pickle.load(f)
        print("Computing and saving labels")
        with open(os.path.join(folder, 'flintstones_annotations_v1-0.json'), 'r') as f:
            annotations = json.load(f)
        labels = {}
        for sample in annotations:
            sample_characters = [c["entityLabel"].strip().lower() for c in sample["characters"]]
            labels[sample["globalID"]] = [1 if c.lower() in sample_characters else 0 for c in unique_characters]
        with open(labels_file, 'wb') as f:
            pickle.dump(labels, f)
        return labels

    def sample_image(self, im):
        """Crop a random ``shorter x shorter`` square at vertical offset ``se * shorter``.

        Returns ``(crop, se)``. Assumes frames are stacked along the image height — TODO confirm.
        """
        shorter, longer = min(im.size[0], im.size[1]), max(im.size[0], im.size[1])
        video_len = int(longer / shorter)
        se = np.random.randint(0, video_len, 1)[0]
        return im.crop((0, se * shorter, shorter, (se + 1) * shorter)), se

    def __getitem__(self, item):
        """Return ``[clip_id] + follower_ids`` for the ``item``-th clip of the split."""
        return [self.orders[item]] + self.followings[self.orders[item]]

    def __len__(self):
        return len(self.orders)


class StoryDataset(torch.utils.data.Dataset):
    """Story-level dataset: each item is a clip plus its followers from a VideoFolderDataset.

    Args:
        dataset: ``VideoFolderDataset``-like index providing ``dir_path``, ``labels``,
            ``embeds``, ``sent2idx``, ``orders`` and indexing that returns globalID lists.
        transform: callable applied to the stacked ``T x H x W x C`` frame array.
        return_caption: also return token ids/masks built from a MART vocabulary.
        out_dir: if set, frames are read from ``img-<item>-<idx>.png`` files here
            instead of being sampled from ``video_frames_sampled``.
        densecap: currently unused; dense-caption targets are disabled.
    """

    def __init__(self, dataset, transform, return_caption=False, out_dir=None, densecap=False):
        self.dir_path = dataset.dir_path
        self.dataset = dataset
        self.transforms = transform
        self.labels = dataset.labels
        self.return_caption = return_caption

        with open(os.path.join(self.dir_path, 'flintstones_annotations_v1-0.json')) as f:
            annotations = json.load(f)
        self.descriptions = {sample["globalID"]: sample["description"] for sample in annotations}

        if self.return_caption:
            self.init_mart_vocab()
            self.max_len = self.tokenize_descriptions()
            print("Max sequence length = %s" % self.max_len)
        else:
            self.vocab = None
        self.out_dir = out_dir

        # if densecap:
        #     self.densecap_dataset = DenseCapDataset(self.dir_path)
        # else:
        self.densecap_dataset = None

    def tokenize_descriptions(self):
        """Tokenize every (lower-cased) description; return the longest length + 2."""
        caption_lengths = []
        self.tokenized_descriptions = {}
        for img_id, descs in self.descriptions.items():
            self.tokenized_descriptions[img_id] = nltk.tokenize.word_tokenize(descs.lower())
            caption_lengths.append(len(self.tokenized_descriptions[img_id]))
        return max(caption_lengths) + 2

    def init_mart_vocab(self):
        """Build or load the MART vocabulary cached at ``<dir_path>/mart_vocab.pkl``."""
        vocab_file = os.path.join(self.dir_path, 'mart_vocab.pkl')
        self.vocab = Vocabulary(vocab_threshold=5,
                                vocab_file=vocab_file,
                                annotations_file=os.path.join(self.dir_path, 'flintstones_annotations_v1-0.json'),
                                vocab_from_file=os.path.exists(vocab_file))

    def save_story(self, output, save_path='./'):
        """Write a story's frames (concatenated along axis 0) to image.png and its captions to text.txt."""
        images = output['images_numpy']
        frames = [np.squeeze(images[i]) for i in range(images.shape[0])]
        PIL.Image.fromarray(np.concatenate(frames, axis=0)).save(save_path + 'image.png')
        with open(save_path + 'text.txt', 'w') as fid:
            for text in output['text']:
                fid.write(text + '\n')

    def _sentence_to_idx(self, sentence_tokens):
        """Convert tokens to ``max_len`` vocabulary ids plus a validity mask.

        Tokens are truncated to ``max_len - 2`` and right-padded with
        ``vocab.pad_word``; out-of-vocabulary words map to ``vocab.unk_word``.
        No BOS/EOS tokens are inserted. The mask is 1 for real tokens, 0 for padding.

        Returns:
            ``(input_ids, mask)``, two lists of length ``max_len``.
        """
        max_t_len = self.max_len
        tokens = sentence_tokens[:max_t_len - 2]  # slicing copies, so the cached tokens are untouched
        valid_l = len(tokens)
        mask = [1] * valid_l + [0] * (max_t_len - valid_l)
        padded = tokens + [self.vocab.pad_word] * (max_t_len - valid_l)
        unk_id = self.vocab.word2idx[self.vocab.unk_word]
        input_ids = [self.vocab.word2idx.get(t, unk_id) for t in padded]
        return input_ids, mask

    def __getitem__(self, item):
        """Return story ``item`` as a dict.

        Keys: ``images`` (``self.transforms`` of the stacked frames), ``text`` (list of
        descriptions), ``description`` (stacked sentence embeddings), ``images_numpy``
        (untransformed stacked frames), ``labels`` (float32 character labels); plus
        ``input_ids``/``masks`` when ``return_caption`` and ``boxes``/``caps``/``caps_len``
        when a dense-caption dataset is attached.
        """
        lists = self.dataset[item]
        labels = []
        images = []
        text = []
        input_ids = []
        masks = []
        sent_embeds = []
        for idx, globalID in enumerate(lists):
            if self.out_dir:
                with PIL.Image.open(os.path.join(self.out_dir, 'img-%s-%s.png' % (item, idx))) as img:
                    im = img.convert('RGB')
            else:
                # One randomly chosen frame per clip.
                arr = np.load(os.path.join(self.dir_path, 'video_frames_sampled', globalID + '.npy'))
                im = arr[randrange(arr.shape[0])]
            images.append(np.expand_dims(np.array(im), axis=0))
            text.append(self.descriptions[globalID])
            labels.append(np.expand_dims(self.labels[globalID], axis=0))
            sent_embeds.append(np.expand_dims(self.dataset.embeds[self.dataset.sent2idx[globalID]], axis=0))

            if self.return_caption:
                input_id, mask = self._sentence_to_idx(self.tokenized_descriptions[globalID])
                input_ids.append(np.expand_dims(input_id, axis=0))
                masks.append(np.expand_dims(mask, axis=0))

        sent_embeds = np.concatenate(sent_embeds, axis=0)
        labels = np.concatenate(labels, axis=0)
        images = np.concatenate(images, axis=0)
        # image is T x H x W x C
        transformed_images = self.transforms(images)
        # After the video transform used in this notebook, image is C x T x H x W

        sent_embeds = torch.tensor(sent_embeds)
        labels = torch.tensor(np.array(labels).astype(np.float32))

        data_item = {'images': transformed_images, 'text': text, 'description': sent_embeds,
                     'images_numpy': images, 'labels': labels}

        if self.return_caption:
            input_ids = torch.tensor(np.concatenate(input_ids))
            masks = torch.tensor(np.concatenate(masks))
            data_item.update({'input_ids': input_ids, 'masks': masks})

        if self.densecap_dataset:
            boxes, caps, caps_len = [], [], []
            for v in lists:
                # NOTE(review): the [2:-1] slice looks like it strips a bytes repr (b'...') — confirm id format.
                img_id = str(v).replace('.png', '')[2:-1]
                anns = self.densecap_dataset[img_id + '.png']
                boxes.append(torch.as_tensor([ann['box'] for ann in anns], dtype=torch.float32))
                caps.append(torch.as_tensor([ann['cap_idx'] for ann in anns], dtype=torch.long))
                caps_len.append(torch.as_tensor([sum([1 for k in ann['cap_idx'] if k != 0]) for ann in anns],
                                                dtype=torch.long))
            data_item.update({
                'boxes': torch.cat(boxes),
                'caps': torch.cat(caps),
                'caps_len': torch.cat(caps_len),
            })

        return data_item

    def __len__(self):
        return len(self.dataset.orders)


class ImageDataset(torch.utils.data.Dataset):
    """Single-frame dataset over a VideoFolderDataset index, with story-level conditioning.

    Args:
        dataset: ``VideoFolderDataset``-like index providing ``dir_path``, ``labels``,
            ``embeds``, ``sent2idx``, ``orders`` and indexing that returns globalID lists.
        transform: callable applied to the sampled frame (a numpy array).
        return_caption: also return token ids/mask built from a MART vocabulary.
        densecap: currently unused; dense-caption targets are disabled.
    """

    def __init__(self, dataset, transform, return_caption=False, densecap=False):
        self.dir_path = dataset.dir_path
        self.dataset = dataset
        self.transforms = transform
        self.labels = dataset.labels
        self.return_caption = return_caption

        with open(os.path.join(self.dir_path, 'flintstones_annotations_v1-0.json')) as f:
            annotations = json.load(f)
        self.descriptions = {sample["globalID"]: sample["description"] for sample in annotations}

        if self.return_caption:
            self.init_mart_vocab()
            self.max_len = self.tokenize_descriptions()
            print("Max sequence length = %s" % self.max_len)
        else:
            self.vocab = None

        # if densecap:
        #     self.densecap_dataset = DenseCapDataset(self.dir_path)
        # else:
        self.densecap_dataset = None

    def tokenize_descriptions(self):
        """Tokenize every (lower-cased) description; return the longest length + 2."""
        caption_lengths = []
        self.tokenized_descriptions = {}
        for img_id, descs in self.descriptions.items():
            self.tokenized_descriptions[img_id] = nltk.tokenize.word_tokenize(descs.lower())
            caption_lengths.append(len(self.tokenized_descriptions[img_id]))
        return max(caption_lengths) + 2

    def _sentence_to_idx(self, sentence_tokens):
        """Convert tokens to ``max_len`` vocabulary ids plus a validity mask.

        Tokens are truncated to ``max_len - 2`` and right-padded with
        ``vocab.pad_word``; out-of-vocabulary words map to ``vocab.unk_word``.
        No BOS/EOS tokens are inserted. The mask is 1 for real tokens, 0 for padding.

        Returns:
            ``(input_ids, mask)``, two lists of length ``max_len``.
        """
        max_t_len = self.max_len
        tokens = sentence_tokens[:max_t_len - 2]  # slicing copies, so the cached tokens are untouched
        valid_l = len(tokens)
        mask = [1] * valid_l + [0] * (max_t_len - valid_l)
        padded = tokens + [self.vocab.pad_word] * (max_t_len - valid_l)
        unk_id = self.vocab.word2idx[self.vocab.unk_word]
        input_ids = [self.vocab.word2idx.get(t, unk_id) for t in padded]
        return input_ids, mask

    def init_mart_vocab(self):
        """Build or load the MART vocabulary cached at ``<dir_path>/mart_vocab.pkl``."""
        vocab_file = os.path.join(self.dir_path, 'mart_vocab.pkl')
        self.vocab = Vocabulary(vocab_threshold=5,
                                vocab_file=vocab_file,
                                annotations_file=os.path.join(self.dir_path, 'flintstones_annotations_v1-0.json'),
                                vocab_from_file=os.path.exists(vocab_file))

    def __getitem__(self, item):
        """Return one random frame of clip ``item`` plus story-level conditioning.

        Keys: ``images`` (transformed frame), ``text`` (description), ``description``
        (this clip's sentence embedding), ``labels`` (float32 character labels),
        ``content`` (stacked embeddings of the clip and its followers); plus
        ``input_id``/``mask`` when ``return_caption`` and ``boxes``/``caps``/``caps_len``
        when a dense-caption dataset is attached.

        Raises:
            KeyError: if a dense-caption dataset is attached but has no entry for the clip.
        """
        story_ids = self.dataset[item]
        globalID = story_ids[0]  # the clip this single image is sampled from

        arr = np.load(os.path.join(self.dir_path, 'video_frames_sampled', globalID + '.npy'))
        image = np.array(arr[randrange(arr.shape[0])])
        text = self.descriptions[globalID]
        label = np.array(self.labels[globalID]).astype(np.float32)
        sent_embed = self.dataset.embeds[self.dataset.sent2idx[globalID]]

        input_id = None
        mask = None
        if self.return_caption:
            input_id, mask = self._sentence_to_idx(self.tokenized_descriptions[globalID])
            input_id = np.array(input_id)
            mask = np.array(mask)

        # Conditioning vector: embeddings of the whole story (clip + followers), one row per clip.
        # A distinct loop name keeps `globalID` pointing at the sampled clip for the densecap lookup.
        sent_embeds = torch.tensor(np.stack(
            [self.dataset.embeds[self.dataset.sent2idx[story_id]] for story_id in story_ids], axis=0))

        image = self.transforms(image)
        data_item = {'images': image, 'text': text, 'description': sent_embed,
                     'labels': label, 'content': sent_embeds}

        if self.return_caption:
            data_item.update({'input_id': torch.tensor(input_id), 'mask': torch.tensor(mask)})

        if self.densecap_dataset:
            path = globalID + '.png'
            try:
                anns = self.densecap_dataset[path]
            except KeyError:
                raise KeyError('no dense captions for %s' % path) from None
            data_item.update({
                'boxes': torch.as_tensor([ann['box'] for ann in anns], dtype=torch.float32),
                'caps': torch.as_tensor([ann['cap_idx'] for ann in anns], dtype=torch.long),
                'caps_len': torch.as_tensor([sum([1 for k in ann['cap_idx'] if k != 0]) for ann in anns],
                                            dtype=torch.long),
            })

        return data_item

    def __len__(self):
        return len(self.dataset.orders)


class StoryImageDataset(torch.utils.data.Dataset):
    """Stories of ``video_len`` consecutive FlintstonesSV clips as stacked image tensors.

    Each item is ``(images, labels)``: ``images`` stacks one transformed frame per
    clip and ``labels`` stacks the clips' character-label rows.

    Args:
        data_folder: directory with ``labels.pkl``, ``train-val-test_split.json`` and
            ``video_frames/<globalID>.npy``.
        im_input_size: passed to ``transforms.Resize`` (and ``CenterCrop`` outside train).
        out_img_folder: if set, frames are read from ``img-<item>-<idx>.png`` files here.
        mode: ``'train'`` or ``'val'``; any other value selects the test split.
        video_len: clips per story (the first clip plus ``video_len - 1`` followers).
        transform: accepted for API compatibility but unused; the transform is chosen by ``mode``.
    """

    def __init__(self, data_folder, im_input_size,
                 out_img_folder=None,
                 mode='train',
                 video_len=5,
                 transform=None):
        self.followings = {}
        self.data_folder = data_folder
        # NOTE: pickle can execute arbitrary code; only load trusted files.
        with open(os.path.join(data_folder, 'labels.pkl'), 'rb') as f:
            self.labels = pickle.load(f)
        self.video_len = video_len
        min_len = video_len - 1

        with open(os.path.join(self.data_folder, 'train-val-test_split.json'), 'r') as f:
            splits = json.load(f)
        train_ids, val_ids, test_ids = splits["train"], splits["val"], splits["test"]

        cache_file = os.path.join(data_folder, 'following_cache' + str(min_len) + '.pkl')
        if os.path.exists(cache_file):
            with open(cache_file, 'rb') as f:
                self.followings = pickle.load(f)
        else:
            all_clips = sorted(train_ids + val_ids + test_ids)
            # NOTE: requires `tqdm` bound to the tqdm.tqdm callable, not the tqdm module.
            for idx, clip in enumerate(tqdm(all_clips, desc="Counting total number of frames")):
                # globalIDs carry season/episode in fields 1 and 3 when split on '_'.
                season, episode = int(clip.split('_')[1]), int(clip.split('_')[3])
                nxt = all_clips[idx + 1:idx + min_len + 1]
                if all(int(c.split('_')[1]) == season and int(c.split('_')[3]) == episode for c in nxt):
                    self.followings[clip] = nxt
            with open(cache_file, 'wb') as f:
                pickle.dump(self.followings, f)

        # Clips near the end of the sorted list get shorter follower slices; keep only full stories.
        self.filtered_followings = {clip: f for clip, f in self.followings.items() if len(f) == min_len}
        self.followings = self.filtered_followings

        train_ids = [tid for tid in train_ids if tid in self.followings]
        val_ids = [vid for vid in val_ids if vid in self.followings]
        test_ids = [tid for tid in test_ids if tid in self.followings]

        if mode == 'train':
            self.ids = train_ids
            self.transform = transforms.Compose([
                transforms.Resize(im_input_size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
            ])
        else:
            # Evaluation sets are capped at fixed sizes (2060 val / 2304 test stories).
            self.ids = val_ids[:2060] if mode == "val" else test_ids[:2304]
            self.transform = transforms.Compose([
                transforms.Resize(im_input_size),
                transforms.CenterCrop(im_input_size),
                transforms.ToTensor(),
            ])

        self.out_dir = out_img_folder

    def __getitem__(self, item):
        """Return ``(stacked transformed frames, stacked label rows)`` for story ``item``."""
        globalIDs = [self.ids[item]] + self.followings[self.ids[item]]

        images = []
        for idx, globalID in enumerate(globalIDs):
            if self.out_dir:
                with PIL.Image.open(os.path.join(self.out_dir, 'img-%s-%s.png' % (item, idx))) as img:
                    images.append(img.convert('RGB'))
            else:
                # One randomly chosen frame per clip.
                arr = np.load(os.path.join(self.data_folder, 'video_frames', globalID + '.npy'))
                images.append(PIL.Image.fromarray(arr[randrange(arr.shape[0])]))

        labels = [self.labels[globalID] for globalID in globalIDs]
        return torch.stack([self.transform(image).squeeze(0) for image in images]), torch.tensor(np.vstack(labels))

    def __len__(self):
        return len(self.ids)

In [29]:
# Import the tqdm *callable*: a bare `import tqdm` rebinds the name to the module,
# which breaks the `tqdm(...)` progress-bar call in StoryImageDataset.
from tqdm import tqdm
dir_path = "./flintstones_data/"
base = VideoFolderDataset(dir_path, cache=dir_path, min_len=4, mode="train")

Total number of clips 20132


In [36]:
# Story-level variant (currently unused): StoryDataset(base, video_transforms)
imagedataset = ImageDataset(dataset=base, transform=image_transforms)

In [37]:
# Inspect the first training sample. Index directly: next(iter(...)) on a map-style
# Dataset relies on the legacy sequence protocol and hides an empty dataset as StopIteration.
imagedataset[0]

{'images': tensor([[[-0.4353, -0.4353, -0.4745,  ...,  0.1059, -0.0275,  0.2549],
          [-0.4510, -0.4745, -0.4510,  ...,  0.1765,  0.1216,  0.2314],
          [-0.4824, -0.5059, -0.4902,  ...,  0.0118,  0.1294,  0.1529],
          ...,
          [-0.8588, -0.9216, -0.9451,  ...,  0.2078,  0.0745, -0.1451],
          [-0.7333, -0.9137, -0.9373,  ...,  0.2314,  0.1216, -0.0275],
          [-0.2706, -0.7255, -0.8902,  ...,  0.2078,  0.0667, -0.0039]],
 
         [[-0.4510, -0.4510, -0.5216,  ..., -0.0902, -0.2157,  0.0667],
          [-0.4667, -0.4980, -0.5059,  ..., -0.0275, -0.0431,  0.0745],
          [-0.4745, -0.5137, -0.5294,  ..., -0.1686, -0.0353,  0.0039],
          ...,
          [-0.3255, -0.3882, -0.4275,  ...,  0.0588, -0.0431, -0.2314],
          [-0.4745, -0.4824, -0.4353,  ...,  0.0902,  0.0196, -0.1137],
          [-0.1922, -0.4588, -0.4902,  ...,  0.0667, -0.0275, -0.0980]],
 
         [[-0.6471, -0.6549, -0.7020,  ..., -0.4510, -0.5608, -0.2863],
          [-0.6706