In [1]:
import numpy as np
from PIL import Image

from src.data.dataset import Spatial_Coordination_Dataset
from src.data.dataset import Spatial_Free_Recall_Dataset
from src.data.dataset import Spatial_Integration_Dataset
from src.data.dataset import Spatial_Memory_Updating_Dataset
from src.data.dataset import Spatial_Task_Switching_Dataset
from src.data.dataset import Complex_WM_Dataset
from src.data.dataset import Visual_Item_Recognition_RI_2C_Dataset
from src.data.dataset import Visual_Serial_Recall_Recognition_Dataset
from src.data.dataset import Color_Orientation_Size_Gap_Conjunction_Change_Detection_Dataset

  from .autonotebook import tqdm as notebook_tqdm


### Spatial Coordination Task (SC Task)

In [2]:
# Spatial Coordination dataset, 'test' split, with data_path set to
# ./wm_bench_data/. show_gt_pattern=True is passed here; the trial unpacked
# later carries a gt_pattern_img field (presumably enabled by this flag --
# confirm against the dataset class).
sc_task_dataset = Spatial_Coordination_Dataset(
    data_path='./wm_bench_data/',
    max_seq_len=20,
    rs_img_size=96,
    show_gt_pattern=True,
    split='test',
)

Getting data for Spatial_Coordination Task
Data already exists. Skipping data generation.


In [3]:
trial_num = 1

# Fetch a single trial and unpack its fields.
sc_task_trial = sc_task_dataset[trial_num]
stim_seq, gt, gt_pattern_img, seq_len, symmetry_offset = sc_task_trial

# Build one PIL frame per stimulus: move axis 0 to the end (presumably
# channels-first -> channels-last; confirm against the dataset), scale by 255
# (assumes values lie in [0, 1] -- TODO confirm) and cast to uint8.
# Only the first seq_len frames are kept.
stim_seq_gif = [
    Image.fromarray((np.transpose(stim.numpy(), (1, 2, 0)) * 255).astype(np.uint8))
    for stim in stim_seq
][:seq_len]

In [4]:
gif_filename = "sc_task.gif"

# Write an animated GIF: the first frame owns the save call and the remaining
# frames are appended. Per Pillow's GIF writer, loop=0 repeats forever,
# duration is the per-frame display time in milliseconds, and disposal=2
# restores the background before the next frame is drawn.
lead_frame = stim_seq_gif[0]
lead_frame.save(gif_filename, save_all=True, append_images=stim_seq_gif[1:],
                loop=0, duration=1000, disposal=2)

print(f"GIF generated and saved as {gif_filename}")

GIF generated and saved as sc_task.gif


In [5]:
# Import IPython's Image under an alias so the global name `Image` keeps
# referring to PIL.Image, which the GIF-building cells call via
# Image.fromarray. The bare last expression renders the GIF inline.
from IPython.display import Image as IPyImage

IPyImage(url='./sc_task.gif', width=250)

### Spatial Free Recall Task (SFR Task)

In [9]:
from PIL import Image

In [10]:
# Spatial Free Recall dataset, 'test' split, with data_path set to
# ./wm_bench_data/.
sfr_task_dataset = Spatial_Free_Recall_Dataset(
    data_path='./wm_bench_data/',
    max_seq_len=20,
    rs_img_size=96,
    split='test',
)

Getting data for Spatial_Free_Recall Task
Data already exists. Skipping data generation.


In [11]:
trial_num = 3000

# Fetch a single trial and unpack its fields.
sfr_task_trial = sfr_task_dataset[trial_num]
stim_seq, gt_one_hot, seq_len, recall_gt_orig = sfr_task_trial

# Build one PIL frame per stimulus: move axis 0 to the end (presumably
# channels-first -> channels-last; confirm against the dataset), scale by 255
# (assumes values lie in [0, 1] -- TODO confirm) and cast to uint8.
# Only the first seq_len frames are kept.
stim_seq_gif = [
    Image.fromarray((np.transpose(stim.numpy(), (1, 2, 0)) * 255).astype(np.uint8))
    for stim in stim_seq
][:seq_len]

In [12]:
gif_filename = "sfr_task.gif"

# Write an animated GIF: the first frame owns the save call and the remaining
# frames are appended. Per Pillow's GIF writer, loop=0 repeats forever,
# duration is the per-frame display time in milliseconds, and disposal=2
# restores the background before the next frame is drawn.
lead_frame = stim_seq_gif[0]
lead_frame.save(gif_filename, save_all=True, append_images=stim_seq_gif[1:],
                loop=0, duration=1000, disposal=2)

print(f"GIF generated and saved as {gif_filename}")

GIF generated and saved as sfr_task.gif


In [13]:
# Import IPython's Image under an alias so the global name `Image` keeps
# referring to PIL.Image, which the GIF-building cells call via
# Image.fromarray. The bare last expression renders the GIF inline.
from IPython.display import Image as IPyImage

IPyImage(url='./sfr_task.gif', width=250)

### Spatial Integration Task (SI Task)

In [14]:
from PIL import Image

In [15]:
# Spatial Integration dataset, 'test' split, with data_path set to
# ./wm_bench_data/.
si_task_dataset = Spatial_Integration_Dataset(
    data_path='./wm_bench_data/',
    max_seq_len=20,
    rs_img_size=96,
    split='test',
)

Getting data for Spatial_Integration Task
Data already exists. Skipping data generation.


In [16]:
trial_num = 3000

# Fetch a single trial and unpack its fields.
si_task_trial = si_task_dataset[trial_num]
stim_seq, gt, seq_len, part_size = si_task_trial

# Build one PIL frame per stimulus: move axis 0 to the end (presumably
# channels-first -> channels-last; confirm against the dataset), scale by 255
# (assumes values lie in [0, 1] -- TODO confirm) and cast to uint8.
# Only the first seq_len frames are kept.
stim_seq_gif = [
    Image.fromarray((np.transpose(stim.numpy(), (1, 2, 0)) * 255).astype(np.uint8))
    for stim in stim_seq
][:seq_len]

In [17]:
gif_filename = "si_task.gif"

# Write an animated GIF: the first frame owns the save call and the remaining
# frames are appended. Per Pillow's GIF writer, loop=0 repeats forever,
# duration is the per-frame display time in milliseconds, and disposal=2
# restores the background before the next frame is drawn.
lead_frame = stim_seq_gif[0]
lead_frame.save(gif_filename, save_all=True, append_images=stim_seq_gif[1:],
                loop=0, duration=1000, disposal=2)

print(f"GIF generated and saved as {gif_filename}")

GIF generated and saved as si_task.gif


In [18]:
# Import IPython's Image under an alias so the global name `Image` keeps
# referring to PIL.Image, which the GIF-building cells call via
# Image.fromarray. The bare last expression renders the GIF inline.
from IPython.display import Image as IPyImage

IPyImage(url='./si_task.gif', width=250)

### Spatial Memory Updating Task (SMU Task)

In [19]:
from PIL import Image

In [20]:
# Spatial Memory Updating dataset, 'test' split, with data_path set to
# ./wm_bench_data/.
smu_task_dataset = Spatial_Memory_Updating_Dataset(
    data_path='./wm_bench_data/',
    max_seq_len=20,
    rs_img_size=96,
    split='test',
)

Getting data for Spatial_Memory_Updating Task
Data already exists. Skipping data generation.


In [21]:
trial_num = 3000

# Fetch a single trial and unpack its fields.
smu_task_trial = smu_task_dataset[trial_num]
stim_seq, gt, seq_len, set_size = smu_task_trial

# Build one PIL frame per stimulus: move axis 0 to the end (presumably
# channels-first -> channels-last; confirm against the dataset), scale by 255
# (assumes values lie in [0, 1] -- TODO confirm) and cast to uint8.
# Only the first seq_len frames are kept.
stim_seq_gif = [
    Image.fromarray((np.transpose(stim.numpy(), (1, 2, 0)) * 255).astype(np.uint8))
    for stim in stim_seq
][:seq_len]

In [22]:
gif_filename = "smu_task.gif"

# Write an animated GIF: the first frame owns the save call and the remaining
# frames are appended. Per Pillow's GIF writer, loop=0 repeats forever,
# duration is the per-frame display time in milliseconds, and disposal=2
# restores the background before the next frame is drawn.
lead_frame = stim_seq_gif[0]
lead_frame.save(gif_filename, save_all=True, append_images=stim_seq_gif[1:],
                loop=0, duration=1000, disposal=2)

print(f"GIF generated and saved as {gif_filename}")

GIF generated and saved as smu_task.gif


In [23]:
# Import IPython's Image under an alias so the global name `Image` keeps
# referring to PIL.Image, which the GIF-building cells call via
# Image.fromarray. The bare last expression renders the GIF inline.
from IPython.display import Image as IPyImage

IPyImage(url='./smu_task.gif', width=250)

### Visual Serial Recall Task (VSR Task)

In [24]:
from PIL import Image

In [27]:
# Visual Serial Recall/Recognition dataset configured with
# probe_variant='Recall', 'test' split, data_path set to ./wm_bench_data/.
vsr_task_dataset = Visual_Serial_Recall_Recognition_Dataset(
    data_path='./wm_bench_data/',
    max_seq_len=20,
    probe_variant='Recall',
    rs_img_size=96,
    split='test',
)

Getting data for Visual_Serial_Recall Task
Data already exists. Skipping data generation.


In [28]:
trial_num = 3000

# Fetch a single trial and unpack its fields.
vsr_task_trial = vsr_task_dataset[trial_num]
stim_seq, gt, seq_len, list_length = vsr_task_trial

# Build one PIL frame per stimulus: move axis 0 to the end (presumably
# channels-first -> channels-last; confirm against the dataset), scale by 255
# (assumes values lie in [0, 1] -- TODO confirm) and cast to uint8.
# Only the first seq_len frames are kept.
stim_seq_gif = [
    Image.fromarray((np.transpose(stim.numpy(), (1, 2, 0)) * 255).astype(np.uint8))
    for stim in stim_seq
][:seq_len]

In [29]:
gif_filename = "vsr_task.gif"

# Write an animated GIF: the first frame owns the save call and the remaining
# frames are appended. Per Pillow's GIF writer, loop=0 repeats forever,
# duration is the per-frame display time in milliseconds, and disposal=2
# restores the background before the next frame is drawn.
lead_frame = stim_seq_gif[0]
lead_frame.save(gif_filename, save_all=True, append_images=stim_seq_gif[1:],
                loop=0, duration=1000, disposal=2)

print(f"GIF generated and saved as {gif_filename}")

GIF generated and saved as vsr_task.gif


In [30]:
# Import IPython's Image under an alias so the global name `Image` keeps
# referring to PIL.Image, which the GIF-building cells call via
# Image.fromarray. The bare last expression renders the GIF inline.
from IPython.display import Image as IPyImage

IPyImage(url='./vsr_task.gif', width=250)

### Visual Serial Recognition Task (VSRec Task)

In [31]:
from PIL import Image

In [32]:
# Visual Serial Recall/Recognition dataset configured with
# probe_variant='Recognition', 'test' split, data_path set to ./wm_bench_data/.
vsrec_task_dataset = Visual_Serial_Recall_Recognition_Dataset(
    data_path='./wm_bench_data/',
    max_seq_len=20,
    probe_variant='Recognition',
    rs_img_size=96,
    split='test',
)

Getting data for Visual_Serial_Recognition Task
Data already exists. Skipping data generation.


In [33]:
trial_num = 3000

# Fetch a single trial and unpack its fields.
vsrec_task_trial = vsrec_task_dataset[trial_num]
stim_seq, gt, seq_len, list_length, distractor_diff = vsrec_task_trial

# Build one PIL frame per stimulus: move axis 0 to the end (presumably
# channels-first -> channels-last; confirm against the dataset), scale by 255
# (assumes values lie in [0, 1] -- TODO confirm) and cast to uint8.
# Only the first seq_len frames are kept.
stim_seq_gif = [
    Image.fromarray((np.transpose(stim.numpy(), (1, 2, 0)) * 255).astype(np.uint8))
    for stim in stim_seq
][:seq_len]

In [34]:
gif_filename = "vsrec_task.gif"

# Write an animated GIF: the first frame owns the save call and the remaining
# frames are appended. Per Pillow's GIF writer, loop=0 repeats forever,
# duration is the per-frame display time in milliseconds, and disposal=2
# restores the background before the next frame is drawn.
lead_frame = stim_seq_gif[0]
lead_frame.save(gif_filename, save_all=True, append_images=stim_seq_gif[1:],
                loop=0, duration=1000, disposal=2)

print(f"GIF generated and saved as {gif_filename}")

GIF generated and saved as vsrec_task.gif


In [35]:
# Import IPython's Image under an alias so the global name `Image` keeps
# referring to PIL.Image, which the GIF-building cells call via
# Image.fromarray. The bare last expression renders the GIF inline.
from IPython.display import Image as IPyImage

IPyImage(url='./vsrec_task.gif', width=250)

### Visual Item Recognition Task (VIR Task)

In [36]:
from PIL import Image

In [37]:
# Visual Item Recognition (RI_2C) dataset, 'test' split, with data_path set to
# ./wm_bench_data/.
vir_task_dataset = Visual_Item_Recognition_RI_2C_Dataset(
    data_path='./wm_bench_data/',
    max_seq_len=20,
    rs_img_size=96,
    split='test',
)

Getting data for Visual_Item_Recognition_RI_2C Task
Data already exists. Skipping data generation.


In [38]:
trial_num = 3000

# Fetch a single trial and unpack its fields.
vir_task_trial = vir_task_dataset[trial_num]
stim_seq, gt, seq_len, ri, gt_index = vir_task_trial

# Build one PIL frame per stimulus: move axis 0 to the end (presumably
# channels-first -> channels-last; confirm against the dataset), scale by 255
# (assumes values lie in [0, 1] -- TODO confirm) and cast to uint8.
# Only the first seq_len frames are kept.
stim_seq_gif = [
    Image.fromarray((np.transpose(stim.numpy(), (1, 2, 0)) * 255).astype(np.uint8))
    for stim in stim_seq
][:seq_len]

In [39]:
gif_filename = "vir_task.gif"

# Write an animated GIF: the first frame owns the save call and the remaining
# frames are appended. Per Pillow's GIF writer, loop=0 repeats forever,
# duration is the per-frame display time in milliseconds, and disposal=2
# restores the background before the next frame is drawn.
lead_frame = stim_seq_gif[0]
lead_frame.save(gif_filename, save_all=True, append_images=stim_seq_gif[1:],
                loop=0, duration=1000, disposal=2)

print(f"GIF generated and saved as {gif_filename}")

GIF generated and saved as vir_task.gif


In [40]:
# Import IPython's Image under an alias so the global name `Image` keeps
# referring to PIL.Image, which the GIF-building cells call via
# Image.fromarray. The bare last expression renders the GIF inline.
from IPython.display import Image as IPyImage

IPyImage(url='./vir_task.gif', width=250)

### Complex Span (CS Task)

In [41]:
from PIL import Image

In [42]:
# Complex WM (complex span) dataset, 'test' split, with data_path set to
# ./wm_bench_data/.
cs_task_dataset = Complex_WM_Dataset(
    data_path='./wm_bench_data/',
    max_seq_len=20,
    rs_img_size=96,
    split='test',
)

Getting data for Complex_WM Task
Data already exists. Skipping data generation.


In [43]:
trial_num = 3000

# Fetch a single trial and unpack its fields.
cs_task_trial = cs_task_dataset[trial_num]
stim_seq, gt, seq_len, num_distractor, variation = cs_task_trial

# Build one PIL frame per stimulus: move axis 0 to the end (presumably
# channels-first -> channels-last; confirm against the dataset), scale by 255
# (assumes values lie in [0, 1] -- TODO confirm) and cast to uint8.
# Only the first seq_len frames are kept.
stim_seq_gif = [
    Image.fromarray((np.transpose(stim.numpy(), (1, 2, 0)) * 255).astype(np.uint8))
    for stim in stim_seq
][:seq_len]

In [44]:
gif_filename = "cs_task.gif"

# Write an animated GIF: the first frame owns the save call and the remaining
# frames are appended. Per Pillow's GIF writer, loop=0 repeats forever,
# duration is the per-frame display time in milliseconds, and disposal=2
# restores the background before the next frame is drawn.
lead_frame = stim_seq_gif[0]
lead_frame.save(gif_filename, save_all=True, append_images=stim_seq_gif[1:],
                loop=0, duration=1000, disposal=2)

print(f"GIF generated and saved as {gif_filename}")

GIF generated and saved as cs_task.gif


In [45]:
# Import IPython's Image under an alias so the global name `Image` keeps
# referring to PIL.Image, which the GIF-building cells call via
# Image.fromarray. The bare last expression renders the GIF inline.
from IPython.display import Image as IPyImage

IPyImage(url='./cs_task.gif', width=250)

### Spatial Task Switching (STS Task)

In [46]:
from PIL import Image

In [47]:
# Spatial Task Switching dataset configured with variant='Cued', 'test' split,
# data_path set to ./wm_bench_data/.
sts_task_dataset = Spatial_Task_Switching_Dataset(
    data_path='./wm_bench_data/',
    max_seq_len=20,
    variant='Cued',
    rs_img_size=96,
    split='test',
)

Getting data for Spatial_Task_Switching Task
Data already exists. Skipping data generation.


In [48]:
trial_num = 3000

# Fetch a single trial and unpack its fields.
sts_task_trial = sts_task_dataset[trial_num]
stim_seq, gt, seq_len = sts_task_trial

# Build one PIL frame per stimulus: move axis 0 to the end (presumably
# channels-first -> channels-last; confirm against the dataset), scale by 255
# (assumes values lie in [0, 1] -- TODO confirm) and cast to uint8.
# Only the first seq_len frames are kept.
stim_seq_gif = [
    Image.fromarray((np.transpose(stim.numpy(), (1, 2, 0)) * 255).astype(np.uint8))
    for stim in stim_seq
][:seq_len]

In [49]:
gif_filename = "sts_task.gif"

# Write an animated GIF: the first frame owns the save call and the remaining
# frames are appended. Per Pillow's GIF writer, loop=0 repeats forever,
# duration is the per-frame display time in milliseconds, and disposal=2
# restores the background before the next frame is drawn.
lead_frame = stim_seq_gif[0]
lead_frame.save(gif_filename, save_all=True, append_images=stim_seq_gif[1:],
                loop=0, duration=1000, disposal=2)

print(f"GIF generated and saved as {gif_filename}")

GIF generated and saved as sts_task.gif


In [50]:
# Import IPython's Image under an alias so the global name `Image` keeps
# referring to PIL.Image, which the GIF-building cells call via
# Image.fromarray. The bare last expression renders the GIF inline.
from IPython.display import Image as IPyImage

IPyImage(url='./sts_task.gif', width=250)

### Change Detection Task (CD Task)

In [67]:
from PIL import Image

In [52]:
# Change Detection dataset configured with variant='Color', 'test' split,
# data_path set to ./wm_bench_data/.
cd_task_dataset = Color_Orientation_Size_Gap_Conjunction_Change_Detection_Dataset(
    data_path='./wm_bench_data/',
    max_seq_len=20,
    variant='Color',
    rs_img_size=96,
    split='test',
)

Getting data for Color_Orientation_Size_Gap_Conjunction_Change_Detection_Color Task
Data already exists. Skipping data generation.


In [68]:
trial_num = 1600

# Fetch a single trial and unpack its fields.
cd_task_trial = cd_task_dataset[trial_num]
stim_seq, gt, seq_len, ri, set_size = cd_task_trial

# Build one PIL frame per stimulus: move axis 0 to the end (presumably
# channels-first -> channels-last; confirm against the dataset), scale by 255
# (assumes values lie in [0, 1] -- TODO confirm) and cast to uint8.
# Only the first seq_len frames are kept.
stim_seq_gif = [
    Image.fromarray((np.transpose(stim.numpy(), (1, 2, 0)) * 255).astype(np.uint8))
    for stim in stim_seq
][:seq_len]

In [69]:
gif_filename = "cd_task.gif"

# Write an animated GIF: the first frame owns the save call and the remaining
# frames are appended. Per Pillow's GIF writer, loop=0 repeats forever,
# duration is the per-frame display time in milliseconds, and disposal=2
# restores the background before the next frame is drawn.
lead_frame = stim_seq_gif[0]
lead_frame.save(gif_filename, save_all=True, append_images=stim_seq_gif[1:],
                loop=0, duration=1000, disposal=2)

print(f"GIF generated and saved as {gif_filename}")

GIF generated and saved as cd_task.gif


In [71]:
# Import IPython's Image under an alias so the global name `Image` keeps
# referring to PIL.Image, which the GIF-building cells call via
# Image.fromarray. The bare last expression renders the GIF inline.
from IPython.display import Image as IPyImage

IPyImage(url='./cd_task.gif', width=250)