Lien de téléchargement des données: https://cvml.ista.ac.at/AwA2/        

13GB file : https://cvml.ista.ac.at/AwA2/AwA2-data.zip

In [1]:
import sys
print(sys.version)

3.9.13 (main, Aug 25 2022, 23:51:50) [MSC v.1916 64 bit (AMD64)]


In [2]:
import sys
import cv2 # Pour utiliser open_cv, il faut la version de python est 3.7
import os
import csv

import numpy as np 
import pandas as pd 
import math

import torch 
from torch.utils.data import Dataset, DataLoader
import torchvision 
from torchvision.io import read_image
import torchvision.datasets as datasets
import torchvision.transforms as transforms

np.random.seed(0)

In [3]:
# Constant. Should be the path to the folder named JPEGImages, containing the 33K images in its subfolders.
DATA_FOLDER_PATH = "C:\\Users\\1\\Desktop\\data_ift3710\\Animals_with_Attributes2\\"
JPEGIMAGES_FOLDER_PATH = "C:\\Users\\1\\Desktop\\data_ift3710\\Animals_with_Attributes2\\JPEGImages\\"

In [4]:
# quick test
test = JPEGIMAGES_FOLDER_PATH+"fox\\fox_10001.jpg"
img = cv2.imread(test) 
print(img.shape) #ndarray
print(type(img))
cv2.imshow('Sample Image from AwA2 dataset',img)
cv2.waitKey(0)

(764, 918, 3)
<class 'numpy.ndarray'>


-1

In [5]:
labels_dirs = os.listdir(JPEGIMAGES_FOLDER_PATH)
print(labels_dirs)
len(labels_dirs) # 50 labels / subdirectories

['antelope', 'bat', 'beaver', 'blue+whale', 'bobcat', 'buffalo', 'chihuahua', 'chimpanzee', 'collie', 'cow', 'dalmatian', 'deer', 'dolphin', 'elephant', 'fox', 'german+shepherd', 'giant+panda', 'giraffe', 'gorilla', 'grizzly+bear', 'hamster', 'hippopotamus', 'horse', 'humpback+whale', 'killer+whale', 'leopard', 'lion', 'mole', 'moose', 'mouse', 'otter', 'ox', 'persian+cat', 'pig', 'polar+bear', 'rabbit', 'raccoon', 'rat', 'rhinoceros', 'seal', 'sheep', 'siamese+cat', 'skunk', 'spider+monkey', 'squirrel', 'tiger', 'walrus', 'weasel', 'wolf', 'zebra']


50

# Note : Some labels have a low number of images. 

## Possible solutions to explore : 
    Data augmentation : creating new training data by applying random transformations to existing images, such as rotating, cropping, or flipping them.

In [6]:
def find_num_images_per_label(img_dir = JPEGIMAGES_FOLDER_PATH) -> tuple[dict,dict]: 
    """ 
    USEFUL FOR SAMPLING.
    Return a dict with keys as the 50 labels, and values being the number of images in each subdirectory corresponding to label
    and a second dict with the relative numbers (proportion) for every label compared to the total number of images (useful for sampling)"""
    labels_dirs = os.listdir(img_dir)
    num_images_per_label = dict.fromkeys(labels_dirs)
    proportions_images_per_label = dict.fromkeys(labels_dirs)
    total_num_images = 0

    # Update absolute number of images per label
    for i, label in enumerate(labels_dirs) : 
        specific_label_path = os.path.join(img_dir, labels_dirs[i])
        num_images_label = len(os.listdir(specific_label_path))
        total_num_images += num_images_label
        num_images_per_label[label] = num_images_label

    # Update relative number of images per label (proportion)
    for i, label in enumerate(labels_dirs) : 
        num_images_label = num_images_per_label[label]
        proportion_label = round(num_images_label / total_num_images, 4)
        proportions_images_per_label[label] = proportion_label

    return num_images_per_label, proportions_images_per_label

num_images_per_label, proportions_images_per_label = find_num_images_per_label()
print(num_images_per_label)
print(proportions_images_per_label)

{'antelope': 1046, 'bat': 383, 'beaver': 193, 'blue+whale': 174, 'bobcat': 630, 'buffalo': 895, 'chihuahua': 567, 'chimpanzee': 728, 'collie': 1028, 'cow': 1338, 'dalmatian': 549, 'deer': 1344, 'dolphin': 946, 'elephant': 1038, 'fox': 664, 'german+shepherd': 1033, 'giant+panda': 874, 'giraffe': 1202, 'gorilla': 872, 'grizzly+bear': 852, 'hamster': 779, 'hippopotamus': 684, 'horse': 1645, 'humpback+whale': 709, 'killer+whale': 291, 'leopard': 720, 'lion': 1019, 'mole': 100, 'moose': 704, 'mouse': 185, 'otter': 758, 'ox': 728, 'persian+cat': 747, 'pig': 713, 'polar+bear': 868, 'rabbit': 1088, 'raccoon': 512, 'rat': 310, 'rhinoceros': 696, 'seal': 988, 'sheep': 1420, 'siamese+cat': 500, 'skunk': 188, 'spider+monkey': 291, 'squirrel': 1200, 'tiger': 877, 'walrus': 215, 'weasel': 272, 'wolf': 589, 'zebra': 1170}
{'antelope': 0.028, 'bat': 0.0103, 'beaver': 0.0052, 'blue+whale': 0.0047, 'bobcat': 0.0169, 'buffalo': 0.024, 'chihuahua': 0.0152, 'chimpanzee': 0.0195, 'collie': 0.0275, 'cow': 0.

In [7]:
ANNOTATIONS_FILENAME = 'annotations.csv'

def create_annotations_csv_file(annotations_filename = ANNOTATIONS_FILENAME, img_dir = JPEGIMAGES_FOLDER_PATH): 
    """ 
    Create a csv annotations_file, annotations.csv, with two columns, in the format : 
                        path/to/image, label
    
    The annotation csv is necessary for DataLoader.
    """
    
    labels_dirs:list = os.listdir(img_dir)
   
    if os.path.exists(annotations_filename):
        os.remove(annotations_filename)
        print(f'Deleted existent {ANNOTATIONS_FILENAME} file.\n ---------------------------')
    
    with open(annotations_filename, 'w', newline='') as file :
        writer = csv.writer(file, dialect='excel', delimiter=',')

        for i, label in enumerate(labels_dirs) : 

            specific_label_path = os.path.join(img_dir, label)
            images_names = os.listdir(specific_label_path)

            for j, image_name in enumerate(images_names):
                full_path_to_img= os.path.join(specific_label_path, image_name)
                full_path_to_img= os.path.join(label, image_name)

                row = [full_path_to_img, label]
                writer.writerow(row)

    print(f'Sucessfully created {ANNOTATIONS_FILENAME} file.')

#
create_annotations_csv_file()

Deleted existent annotations.csv file.
 ---------------------------
Sucessfully created annotations.csv file.


In [8]:
# labels_in_number = pd.read_csv(DATA_FOLDER_PATH+"classes.txt", delim_whitespace=True,header=None)
labels_dict = {}
with open(DATA_FOLDER_PATH+"classes.txt") as f:
    for line in f:
        # print(line.split())
        (key,val) = line.split()
        labels_dict[val] = int(key)-1
print(labels_dict)

{'antelope': 0, 'grizzly+bear': 1, 'killer+whale': 2, 'beaver': 3, 'dalmatian': 4, 'persian+cat': 5, 'horse': 6, 'german+shepherd': 7, 'blue+whale': 8, 'siamese+cat': 9, 'skunk': 10, 'mole': 11, 'tiger': 12, 'hippopotamus': 13, 'leopard': 14, 'moose': 15, 'spider+monkey': 16, 'humpback+whale': 17, 'elephant': 18, 'gorilla': 19, 'ox': 20, 'fox': 21, 'sheep': 22, 'seal': 23, 'chimpanzee': 24, 'hamster': 25, 'squirrel': 26, 'rhinoceros': 27, 'rabbit': 28, 'bat': 29, 'giraffe': 30, 'wolf': 31, 'chihuahua': 32, 'rat': 33, 'weasel': 34, 'otter': 35, 'buffalo': 36, 'zebra': 37, 'giant+panda': 38, 'deer': 39, 'bobcat': 40, 'pig': 41, 'lion': 42, 'mouse': 43, 'polar+bear': 44, 'collie': 45, 'walrus': 46, 'raccoon': 47, 'cow': 48, 'dolphin': 49}


In [9]:
class AWA2Dataset(Dataset): # Dataset class to serve as input for the DataLoader.
    """ 
    Dataset class to serve as input for the DataLoader.
    Implements all the required methods and more. 
    """

    def __init__(self, annotations_file=ANNOTATIONS_FILENAME, img_dir=JPEGIMAGES_FOLDER_PATH, 
                transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

        numbers_infos_dicts: tuple[dict,dict] = find_num_images_per_label(img_dir=JPEGIMAGES_FOLDER_PATH)
        self.num_images_per_label = numbers_infos_dicts[0]
        self.proportions_images_per_label = numbers_infos_dicts[1]

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        # img_path = self.img_labels.iloc[idx, 0]
        key = self.img_labels.iloc[idx, 1]

        # Mapping the labels from string to tensor
        label = labels_dict[key]

        image = read_image(img_path)
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

In [10]:
dataset = AWA2Dataset()
image,label = dataset[4125]



## TODO : Change transforms. Currently this is not useful.
dataset.transform = transforms.Compose([
                        transforms.ToPILImage(),
                        transforms.Resize((224, 224)),
                        transforms.RandomHorizontalFlip(),
                        transforms.Grayscale(num_output_channels=3),
                        transforms.ToTensor(),
                        transforms.Normalize((0.485, 0.456, 0.406), 
                                             (0.229, 0.224, 0.225))])

# Testing. All good
# random_index = np.random.randint(0, len(dataset))
# image, label = dataset[1]
# print(label)
# print(image)
train_size =  int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size,test_size])

In [11]:
# Experiment with DataLoader. Everything works good
dataloader = DataLoader(dataset = dataset, batch_size=4, shuffle=True)
dataiter = iter(dataloader)
data = next(dataiter)

images, labels = data 
print(labels, images.shape)

tensor([ 1, 38,  6, 15]) torch.Size([4, 3, 224, 224])


In [12]:
# Training loop example
num_epochs = 2 
batch_size = 4
total_samples = len(dataset)
n_iterations = math.ceil(total_samples/batch_size)
print(total_samples, n_iterations)

dataloader = DataLoader(dataset = dataset, batch_size=batch_size, shuffle=True)

for epoch in range(num_epochs) : 
    # loop over trainloader 
    for i, (inputs, labels) in enumerate(dataloader) : 
        
        # Do forward and backward pass, update the weights 
        if(i+1) % 5 == 0 :
            print(f'epoch {epoch+1} / {num_epochs}, step, {i+1}/{n_iterations}, inputs {inputs.shape}')

        if i==20 : 
            print('Completed')
            break


37321 9331
epoch 1 / 2, step, 5/9331, inputs torch.Size([4, 3, 224, 224])
epoch 1 / 2, step, 10/9331, inputs torch.Size([4, 3, 224, 224])
epoch 1 / 2, step, 15/9331, inputs torch.Size([4, 3, 224, 224])
epoch 1 / 2, step, 20/9331, inputs torch.Size([4, 3, 224, 224])
Completed
epoch 2 / 2, step, 5/9331, inputs torch.Size([4, 3, 224, 224])
epoch 2 / 2, step, 10/9331, inputs torch.Size([4, 3, 224, 224])
epoch 2 / 2, step, 15/9331, inputs torch.Size([4, 3, 224, 224])
epoch 2 / 2, step, 20/9331, inputs torch.Size([4, 3, 224, 224])
Completed


In [16]:
import torch.nn as nn
import timm
from vit_pytorch import ViT

# class ViT(nn.Module):
#     def __init__(self, model_name="vit_large_patch16_224_in21k", pretrained=True):
#         super(ViT, self).__init__()
#         self.vit = timm.create_model(model_name, pretrained=pretrained)
#         # Others variants of ViT can be used as well
#         '''
#         1 --- 'vit_small_patch16_224'
#         2 --- 'vit_base_patch16_224'
#         3 --- 'vit_large_patch16_224',
#         4 --- 'vit_large_patch32_224'
#         5 --- 'vit_deit_base_patch16_224'
#         6 --- 'deit_base_distilled_patch16_224',
#         '''

#         # Change the head depending of the dataset used 
#         self.vit.head = nn.Identity()
#     def forward(self, x):
#         x = self.vit.patch_embed(x)
#         cls_token = self.vit.cls_token.expand(x.shape[0], -1, -1)  
#         if self.vit.dist_token is None:
#             x = torch.cat((cls_token, x), dim=1)
#         else:
#             x = torch.cat((cls_token, self.vit.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
#         x = self.vit.pos_drop(x + self.vit.pos_embed)
#         x = self.vit.blocks(x)
#         x = self.vit.norm(x)
        
#         return x[:, 0], x[:, 1:]
    
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle= True)
test_dataloader = DataLoader(test_dataset, batch_size=4, shuffle= True)

model = ViT(
    image_size = 224,
    patch_size = 16,
    num_classes = 50,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)
loss_fun = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

num_epochs = 2
correct = 0
total = 0
for epoch in range(num_epochs):
    for i, (inputs, labels) in enumerate(train_dataloader):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fun(outputs, labels)
        loss.backward()
        optimizer.step()

        #predict
        _,predicted = torch.max(outputs.data,1)
        total += labels.size(0)
        correct += (predicted == labels).squeeze().sum().numpy()

        print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_dataloader)}], Loss: {loss.item():.4f}, accuracy: {(correct / total)*100:.4f}%")


Epoch [1/2], Step [1/3732], Loss: 4.1338, accuracy: 0.0000%
Epoch [1/2], Step [2/3732], Loss: 3.7855, accuracy: 6.2500%
Epoch [1/2], Step [3/3732], Loss: 4.0917, accuracy: 4.1667%
Epoch [1/2], Step [4/3732], Loss: 4.1595, accuracy: 3.1250%
Epoch [1/2], Step [5/3732], Loss: 3.7275, accuracy: 5.0000%
Epoch [1/2], Step [6/3732], Loss: 4.0113, accuracy: 4.1667%
Epoch [1/2], Step [7/3732], Loss: 4.0093, accuracy: 3.5714%
Epoch [1/2], Step [8/3732], Loss: 4.1892, accuracy: 3.1250%
Epoch [1/2], Step [9/3732], Loss: 4.0202, accuracy: 2.7778%
Epoch [1/2], Step [10/3732], Loss: 3.8448, accuracy: 5.0000%
Epoch [1/2], Step [11/3732], Loss: 3.8979, accuracy: 4.5455%
Epoch [1/2], Step [12/3732], Loss: 4.0302, accuracy: 5.2083%
Epoch [1/2], Step [13/3732], Loss: 4.2779, accuracy: 4.8077%


KeyboardInterrupt: 

In [14]:
# model.eval()  # switch to evaluation mode

with torch.no_grad():
    model.eval()
    correct = 0
    total = 0
    for inputs, labels in test_dataloader:
        outputs = model(inputs)
        _,predicted = torch.max(outputs.data,1)
        total += labels.size(0)
        correct += (predicted == labels).squeeze().sum().numpy()

        print(f"test accuracy: {(correct / total)*100:.4f}%")
    

test accuracy: 0.0000%
test accuracy: 12.5000%
test accuracy: 16.6667%
test accuracy: 18.7500%
test accuracy: 20.0000%
test accuracy: 16.6667%
test accuracy: 14.2857%
test accuracy: 12.5000%
test accuracy: 11.1111%
test accuracy: 10.0000%
test accuracy: 9.0909%
test accuracy: 10.4167%
test accuracy: 9.6154%
test accuracy: 8.9286%
test accuracy: 8.3333%
test accuracy: 9.3750%
test accuracy: 10.2941%
test accuracy: 11.1111%
test accuracy: 11.8421%
test accuracy: 11.2500%
test accuracy: 10.7143%
test accuracy: 12.5000%
test accuracy: 13.0435%
test accuracy: 12.5000%
test accuracy: 13.0000%
test accuracy: 12.5000%
test accuracy: 12.0370%
test accuracy: 11.6071%
test accuracy: 11.2069%
test accuracy: 11.6667%
test accuracy: 12.0968%
test accuracy: 12.5000%
test accuracy: 12.8788%
test accuracy: 13.2353%
test accuracy: 13.5714%
test accuracy: 13.1944%
test accuracy: 13.5135%
test accuracy: 13.8158%
test accuracy: 14.1026%
test accuracy: 13.7500%
test accuracy: 14.0244%
test accuracy: 13.6905

test accuracy: 16.4244%
test accuracy: 16.5217%
test accuracy: 16.5462%
test accuracy: 16.6427%
test accuracy: 16.6667%
test accuracy: 16.6905%
test accuracy: 16.6429%
test accuracy: 16.5954%
test accuracy: 16.5483%
test accuracy: 16.6431%
test accuracy: 16.7373%
test accuracy: 16.9014%
test accuracy: 16.9242%
test accuracy: 17.0168%
test accuracy: 17.1089%
test accuracy: 17.0613%
test accuracy: 17.0833%
test accuracy: 17.1745%
test accuracy: 17.2652%
test accuracy: 17.3554%
test accuracy: 17.3764%
test accuracy: 17.3288%
test accuracy: 17.3497%
test accuracy: 17.3025%
test accuracy: 17.3913%
test accuracy: 17.4119%
test accuracy: 17.3649%
test accuracy: 17.3181%
test accuracy: 17.2715%
test accuracy: 17.2252%
test accuracy: 17.2460%
test accuracy: 17.2000%
test accuracy: 17.2207%
test accuracy: 17.2414%
test accuracy: 17.1958%
test accuracy: 17.2164%
test accuracy: 17.1711%
test accuracy: 17.1260%
test accuracy: 17.2120%
test accuracy: 17.2324%
test accuracy: 17.3177%
test accuracy: 1

test accuracy: 17.1397%
test accuracy: 17.1148%
test accuracy: 17.1626%
test accuracy: 17.1739%
test accuracy: 17.1491%
test accuracy: 17.2327%
test accuracy: 17.2439%
test accuracy: 17.2550%
test accuracy: 17.2302%
test accuracy: 17.3132%
test accuracy: 17.3242%
test accuracy: 17.2994%
test accuracy: 17.2747%
test accuracy: 17.2500%
test accuracy: 17.2611%
test accuracy: 17.3077%
test accuracy: 17.3542%
test accuracy: 17.3651%
test accuracy: 17.3404%
test accuracy: 17.3867%
test accuracy: 17.3621%
test accuracy: 17.4082%
test accuracy: 17.4189%
test accuracy: 17.3944%
test accuracy: 17.4402%
test accuracy: 17.4157%
test accuracy: 17.4264%
test accuracy: 17.4020%
test accuracy: 17.3776%
test accuracy: 17.3534%
test accuracy: 17.3291%
test accuracy: 17.3747%
test accuracy: 17.3853%
test accuracy: 17.3958%
test accuracy: 17.4064%
test accuracy: 17.4169%
test accuracy: 17.3928%
test accuracy: 17.3688%
test accuracy: 17.3448%
test accuracy: 17.3209%
test accuracy: 17.3315%
test accuracy: 1

test accuracy: 17.2983%
test accuracy: 17.3058%
test accuracy: 17.2890%
test accuracy: 17.3207%
test accuracy: 17.3524%
test accuracy: 17.3356%
test accuracy: 17.3188%
test accuracy: 17.3263%
test accuracy: 17.3337%
test accuracy: 17.3170%
test accuracy: 17.3003%
test accuracy: 17.2837%
test accuracy: 17.2911%
test accuracy: 17.2985%
test accuracy: 17.2819%
test accuracy: 17.2653%
test accuracy: 17.2488%
test accuracy: 17.2801%
test accuracy: 17.2875%
test accuracy: 17.2710%
test accuracy: 17.2545%
test accuracy: 17.2619%
test accuracy: 17.2455%
test accuracy: 17.2291%
test accuracy: 17.2602%
test accuracy: 17.2438%
test accuracy: 17.2275%
test accuracy: 17.2348%
test accuracy: 17.2185%
test accuracy: 17.2023%
test accuracy: 17.2096%
test accuracy: 17.1934%
test accuracy: 17.2008%
test accuracy: 17.1846%
test accuracy: 17.2154%
test accuracy: 17.2462%
test accuracy: 17.2535%
test accuracy: 17.2608%
test accuracy: 17.2915%
test accuracy: 17.2753%
test accuracy: 17.2825%
test accuracy: 1

test accuracy: 17.2684%
test accuracy: 17.2741%
test accuracy: 17.2797%
test accuracy: 17.2853%
test accuracy: 17.2727%
test accuracy: 17.2602%
test accuracy: 17.2476%
test accuracy: 17.2533%
test accuracy: 17.2589%
test accuracy: 17.2645%
test accuracy: 17.2701%
test accuracy: 17.2757%
test accuracy: 17.2813%
test accuracy: 17.2868%
test accuracy: 17.2744%
test accuracy: 17.2619%
test accuracy: 17.2495%
test accuracy: 17.2370%
test accuracy: 17.2246%
test accuracy: 17.2122%
test accuracy: 17.1999%
test accuracy: 17.2055%
test accuracy: 17.1931%
test accuracy: 17.1987%
test accuracy: 17.2043%
test accuracy: 17.2099%
test accuracy: 17.2155%
test accuracy: 17.2210%
test accuracy: 17.2266%
test accuracy: 17.2321%
test accuracy: 17.2377%
test accuracy: 17.2254%
test accuracy: 17.2488%
test accuracy: 17.2365%
test accuracy: 17.2598%
test accuracy: 17.2475%
test accuracy: 17.2353%
test accuracy: 17.2230%
test accuracy: 17.2285%
test accuracy: 17.2340%
test accuracy: 17.2218%
test accuracy: 1

test accuracy: 17.0508%
test accuracy: 17.0408%
test accuracy: 17.0600%
test accuracy: 17.0646%
test accuracy: 17.0693%
test accuracy: 17.1030%
test accuracy: 17.1076%
test accuracy: 17.1121%
test accuracy: 17.1167%
test accuracy: 17.1068%
test accuracy: 17.1114%
test accuracy: 17.1159%
test accuracy: 17.1060%
test accuracy: 17.1106%
test accuracy: 17.1152%
test accuracy: 17.1053%
test accuracy: 17.0954%
test accuracy: 17.1144%
test accuracy: 17.1045%
test accuracy: 17.1235%
test accuracy: 17.1136%
test accuracy: 17.1037%
test accuracy: 17.1083%
test accuracy: 17.1128%
test accuracy: 17.1174%
test accuracy: 17.1075%
test accuracy: 17.1121%
test accuracy: 17.1166%
test accuracy: 17.1211%
test accuracy: 17.1113%
test accuracy: 17.1158%
test accuracy: 17.1203%
test accuracy: 17.1392%
test accuracy: 17.1294%
test accuracy: 17.1339%
test accuracy: 17.1527%
test accuracy: 17.1714%
test accuracy: 17.1616%
test accuracy: 17.1518%
test accuracy: 17.1706%
test accuracy: 17.1893%
test accuracy: 1

In [15]:
print(model)

ViT(
  (to_patch_embedding): Sequential(
    (0): Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=32, p2=32)
    (1): LayerNorm((3072,), eps=1e-05, elementwise_affine=True)
    (2): Linear(in_features=3072, out_features=1024, bias=True)
    (3): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (transformer): Transformer(
    (layers): ModuleList(
      (0-5): 6 x ModuleList(
        (0): PreNorm(
          (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (fn): Attention(
            (attend): Softmax(dim=-1)
            (dropout): Dropout(p=0.1, inplace=False)
            (to_qkv): Linear(in_features=1024, out_features=1536, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=512, out_features=1024, bias=True)
              (1): Dropout(p=0.1, inplace=False)
            )
          )
        )
        (1): PreNorm(
          (norm): LayerNorm((1024,), eps=1e-05, element