# Contrastive Learning

In [1]:
# Imports
import torch
import torchvision

from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
from torchvision.io import read_image
import torchvision.models as models
import torch.nn as nn
from torchvision import transforms, utils
from torchvision.utils import save_image
from tqdm.auto import tqdm

import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import random

from copy import deepcopy as copy
from time import time
from PIL import Image

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import random

In [3]:
# Custom libraries
from data_utils import *

In [4]:
# GPU setup
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## 1. Data Pipeline

In [5]:
# Paths to the different annotation files.
station_df_paths = [
    '/ourdisk/hpc/ai2es/melreyes/vis/datasets/complete_vis_sets/vis_BATA_2018-06-21_22:05:26_2021-12-30_22:00:27_color_150s/vis_and_images_dataset.csv',
    '/ourdisk/hpc/ai2es/melreyes/vis/datasets/complete_vis_sets/vis_BUFF_2017-03-09_21:15:24_2021-12-30_22:00:27_color_150s/vis_and_images_dataset.csv',
    '/ourdisk/hpc/ai2es/melreyes/vis/datasets/complete_vis_sets/vis_ELMI_2016-06-30_16:55:07_2021-12-30_22:00:30_color_150s/vis_and_images_dataset.csv',
    '/ourdisk/hpc/ai2es/melreyes/vis/datasets/complete_vis_sets/vis_GABR_2016-01-20_20:50:01_2021-12-30_21:45:30_color_150s/vis_and_images_dataset.csv',
    '/ourdisk/hpc/ai2es/melreyes/vis/datasets/complete_vis_sets/vis_GFAL_2017-03-10_18:55:20_2021-12-30_21:50:30_color_150s/vis_and_images_dataset.csv',
    '/ourdisk/hpc/ai2es/melreyes/vis/datasets/complete_vis_sets/vis_JOHN_2017-06-22_14:15:26_2021-12-30_21:35:32_color_150s/vis_and_images_dataset.csv',
    '/ourdisk/hpc/ai2es/melreyes/vis/datasets/complete_vis_sets/vis_MANH_2017-07-18_18:50:33_2021-12-30_21:50:31_color_150s/vis_and_images_dataset.csv',
    '/ourdisk/hpc/ai2es/melreyes/vis/datasets/complete_vis_sets/vis_PENN_2016-05-24_15:05:04_2021-12-30_22:00:34_color_150s/vis_and_images_dataset.csv',
    '/ourdisk/hpc/ai2es/melreyes/vis/datasets/complete_vis_sets/vis_POTS_2018-02-03_15:55:29_2021-12-30_21:35:35_color_150s/vis_and_images_dataset.csv',
    '/ourdisk/hpc/ai2es/melreyes/vis/datasets/complete_vis_sets/vis_QUEE_2017-04-18_18:35:28_2021-12-30_23:00:34_color_150s/vis_and_images_dataset.csv'
]

dataset_root_dir = '/ourdisk/hpc/ai2es/datasets/auto_archive_notyet'

nysm_asos_pairs_df = pd.read_csv('/ourdisk/hpc/ai2es/melreyes/NYSM_data/nysm_asos_station_pairs.csv')

In [6]:
dfs = load_data(df_paths=station_df_paths[0:4], 
                filter_by_suntime=True, 
                pairs_df=nysm_asos_pairs_df)

# train_dfs, valid_dfs, test_dfs = split_data_into_folds(dfs=dfs, 
#                                                        training_years=[2016, 2017, 2018, 2019], 
#                                                        validation_years=[2020], 
#                                                        testing_years=[2021], 
#                                                        rotation=0)

train_dfs, valid_dfs, test_dfs = split_data_into_folds(dfs=dfs, 
                                                       training_years=[2019], 
                                                       validation_years=[2020], 
                                                       testing_years=[2021], 
                                                       rotation=0)

********** Loading Dataframes **********
BATA
	no. examples: (99933, 9)
	Years: [2018 2019 2020 2021]
********** Loading Dataframes **********
BUFF
	no. examples: (128768, 9)
	Years: [2017 2018 2019 2020 2021]
********** Loading Dataframes **********
ELMI
	no. examples: (167768, 9)
	Years: [2016 2017 2018 2019 2020 2021]
********** Loading Dataframes **********
GABR
	no. examples: (143819, 9)
	Years: [2016 2017 2018 2019 2020 2021]


In [7]:
transform = transforms.Compose([
    transforms.Resize((256, 256))
])

In [8]:
from io import BytesIO

class ImageDataset(Dataset):
    
    def __init__(self, df, 
                 root_dir, 
                 image_col, 
                 label_col, 
                 transform=None, 
                 target_transform=None, 
                 cache_size=None, 
                 cache_path=None, 
                 shuffle=False):
        """
        Loads images from a dataframe, and optionally caches them to main memory or local cache, but not both. 

        :param df: pd.DataFrame
        :param transform: transformation to apply to the images
        :param target_transform: transformation to apply to the labels
            
        DEPRECATED
        :param in_memory: bool, load the entire dataset into memory if True
        :param cache_size: maximum number of dataset images to cache to main memory
        
        TODO(mel): add warning for --> raises an error if both are cached
        """
        if shuffle:
            self.df = df.sample(frac=1, random_state=42).reset_index(drop=True)
        
        # storing the order of files in cache
        self.df['cache_index'] = df.columns.get_loc()
        # shuffle the dataframe?  After each epoch?
        self.shuffle = shuffle
        # root directory for the images
        self.root_dir = root_dir
        # column name for the images in the dataframe
        self.image_col = image_col
        # column name for the labels in the dataframe
        self.label_col = label_col
        # transform function for the images
        self.transform = transform
        # transform function for the labels
        self.target_transform = target_transform
        # number of elements to store in the cache
        self.cache_size = cache_size
        # path to the cache file
        self.cache_path = cache_path
        # in-memory cache
        self.cache_dict = None
        
        if self.cache_size:
            self.cache_dict = {}
            
        self.tensorsize = None
        self.labelsize = None
        
        self.fpw = open(cache_path, 'wb')
        self.fpr = open(cache_path, 'rb')
        
    
    def __del__(self):
        self.fpw.close()
        self.fpr.close()

            
    def __len__(self):
        return len(self.df)
    
    
    def shuffle(self):
        self.df = self.df.sample(frac=1).reset_index(drop=True)
        

    def _read_tensors_from_iostream(index):
        # seek to the position of the array
        self.fpr.seek(index * (self.tensorsize + self.labelsize))
        # return read array resized to the correct shape
        tensor = torch.load(BytesIO(fp.read(self.tensorsize)))
        self.fpr.seek(self.tensorsize, 1)
        label = torch.load(BytesIO(fp.read(blocksize)))
        return tensor, label

    def _write_tensors_to_iostream(index, tensor, label):
        # seek to the position of the array
        self.fpw.seek(index * (self.tensorsize + self.labelsize))
        torch.save(tensor, self.fpw)
        torch.save(label, self.fpw)
        self.fpw.flush()
    
    
    def _fetch_from_in_memory_cache(self, idx):
        """ fetch the image you wanted from disk """
        cache_id = self.df.iloc[idx]['cache_index']

        return self.cache_dict[cache_id]
    
    
    def _fetch_from_local_cache(self, idx):
        """ fetch the image you wanted from file on disk """
        tensor, label = self._read_tensors_from_iostream(self.df.iloc[idx]['cache_index'])

        return tensor, label
    
    
    def _fetch_from_disk(self, idx):
        """ fetch the image you wanted from disk """
        image_path = os.path.join(self.root_dir, self.df[self.image_col].iloc[idx])
        image = read_image(image_path)

        label = self.df[self.label_col].iloc[idx]

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)

        return image.to(torch.float32), label
    
    
    def _add_to_local_cache(self, idx):
        
        image, label = self._fetch_from_disk(idx)
        
        self.write_tensors_to_iostream(self.df.iloc[idx]['cache_index'], image, label)
        
        return image, label
    
    
    def _add_to_in_memory_cache(self, idx):
        
        cache_id = self.df.iloc[idx]['cache_index']
        self.cache_dict[cache_id] = self._fetch_from_disk(idx)
            
        return self.cache_dict[idx]

    
    def _fetch_from_cache(self, idx):
        
        if idx in self.cache_dict:
            # If item is in cache, return it. 
            cache_id = self.df[self.image_col].iloc[idx]
            
            if self.cache_path:
                return self._fetch_from_local_cache(cache_id, idx)
            
            return self._fetch_from_in_memory_cache(cache_id, idx)
        
        elif len(self.cache_dict) <= self.cache_size:
            # If cache is not full, add item to cache and return it.
            
            if self.cache_path:
                return self._add_to_local_cache(idx)
            
            return self._add_to_in_memory_cache(idx)
            
        elif len(self.cache_dict) > self.cache_size:
            # If cache is full remove an entry.
            self.cache_dict.pop(random.choice(list(self.cache_dict.keys())))
            return self._fetch_from_disk(idx)
        
        
    def _fetch(self, idx):
        
        # if the idx is an integer
        if self.cache_dict is not None:
            return self._fetch_from_cache(idx)
        else:
            # otherwise just fetch and return from disk
            return self._fetch_from_disk(idx)

        
    def _normalize_image(self, image):
        return (image - 127.5) / 128

    
    def __getitem__(self, idx):
        
        if self.tensorsize is None:
            bytesio = BytesIO()
            tensor, label = self[0]
            torch.save(tensor, bytesio)
            self.tensorsize = bytesio.getbuffer().nbytes # size is given in bytes
            torch.save(label, bytesio)
            self.labelsize = bytesio.getbuffer().nbytes - self.tensorsize # size is given in bytes
        
        if isinstance(idx, int):
            
            image, label = self._fetch(idx)
            # TODO: remove, this is best handled by a transform
            return self._normalize_image(image), label
        
        elif isinstance(idx, slice): 
            
            return ImageDataset(self.df.iloc[idx], self.root_dir, self.image_col, self.label_col, 
                 transform=self.transform, target_transform=self.target_transform, 
                 cache_size=self.cache_size, cache_path=self.cache_path)
        
        else:
            raise ValueError(f'dataset index must be integer or slice found: {type(idx)}')
            
        
            

In [9]:
class ContrastiveDataset(Dataset):
    
    def __init__(self, datasets, n_images):
        
        if len(datasets) < n_images:
            raise ValueError('...')
        
        self.datasets = datasets
        self.n_images = n_images
        
    def __len__(self):
        return max([ len(ds) for ds in self.datasets ])
    
    def __reshape_batch__(self, batch):
        
        images = [ torch.tensor(sample['image']) for sample in batch ]
        labels = [ torch.tensor(sample['label']) for sample in batch ]
        
        return images, labels
    
    def __getitem__(self, idx):
        
        batch_start = (idx * self.n_images) % len(self.datasets)
        
        # batch = [ self.datasets[i % len(self.datasets)][idx % len(self.datasets[i % len(self.datasets)])] for i in np.arange(batch_start, batch_start + self.n_images) ]

        
        batch_images = []
        batch_labels = []
        
        for i in np.arange(batch_start, batch_start + self.n_images):
    
            ds = self.datasets[i % len(self.datasets)]
        
            image, label = ds[idx % len(ds)]
            
            batch_images.append(image)
            batch_labels.append(label)
    
    
        return {'images': torch.stack(batch_images), 'labels': torch.tensor(batch_labels)}

In [78]:
a = [3,4,5]
random.shuffle(a)
a

[3, 4, 5]

In [76]:
# def SampleDataset(Dataset):
    
#     def __init__(self, datasets):
        
#         self.datasets = datasets

#     def __len__(self):
#         return min([len(ds) for ds in self.datasets])
    
#     def __getitem__(self, idx):
        
#         ds = self.datasets[torch.randint(len(self.datasets), (1,))[0]]
        
#         return ds[idx % len(ds)], ds[idx + 1 % len(ds)]
        

In [53]:
def create_station_dataset(df):

    datasets = []

    for vis_reading in df['visibility (mi)'].unique():
        vis_df = df[df['visibility (mi)'] == vis_reading]

        ds = ImageDataset(vis_df, dataset_root_dir, 'image_path', 'visibility (mi)', 
                         transform=transform)

        datasets.append(ds)
        
    return ContrastiveDataset(datasets, 10)

In [79]:
def create_class_datasets(df):
    
    datasets = []

    for vis_reading in df['visibility (mi)'].unique():
        vis_df = df[df['visibility (mi)'] == vis_reading]

        ds = ImageDataset(vis_df, dataset_root_dir, 'image_path', 'visibility (mi)', 
                         transform=transform)

        datasets.append(ds)
        
    return datasets

In [80]:
class_datasets = create_class_datasets(train_dfs[0])

In [151]:
# len(class_datasets[-1])

In [162]:
class_train_loaders = [DataLoader(dataset, batch_size=16, shuffle=True, worker_init_fn = lambda id: np.random.seed(int(time)), drop_last=True) 
                       for dataset in class_datasets]


In [153]:
# class SamplingDataset(Dataset):
    
#     def __init__(self, datasets, iter_approach='min'):
        
#         self.datasets = datasets
        
#         self.datasets_markers = [0] * len(datasets)
        
#         self.iter_count_per_dataset = self._calculate_steps_per_iteration(iter_approach)
        
#         self.round_robin = np.arange(0, len(datasets))
#         # np.random.shuffle(self.round_robin)
        
#     def _calculate_steps_per_iteration(self, iter_approach):
        
#         counts = [ len(ds) for ds in self.datasets ]
#         counts.sort()
        
#         if iter_approach == 'min':
#             return min(counts)
#         elif iter_approach == 'max':
#             return max(counts)

#         raise ValueError('iter_approach must be either \"min\" or \"max\". ')
            
    
#     def __len__(self):
#         return self.iter_count_per_dataset * len(self.datasets)
        
        
#     def __getitem__(self, idx):
        
#         # if idx % len(self.datasets) == 0:
#         #     np.random.shuffle(self.round_robin)
            
#         ds_idx = self.round_robin[idx % len(self.datasets)]
        
#         image, label = self.datasets[ds_idx][self.datasets_markers[ds_idx]]
        
#         self.datasets_markers[ds_idx] = (self.datasets_markers[ds_idx] + 1) % len(self.datasets[ds_idx])
            
#         return image, label

## 2. Model Build

In [154]:
from torch.nn import Module
from torch.nn import Conv2d
from torch.nn import Linear
from torch.nn import MaxPool2d
from torch.nn import ReLU
from torch.nn import LogSoftmax
from torch import flatten
import torch.nn.functional as F

In [170]:
class CNet(Module):
    def __init__(self, numChannels=3):
        super(CNet, self).__init__()
        
        self.conv1 = Conv2d(in_channels=numChannels, out_channels=16, kernel_size=(5, 5))
        self.relu1 = ReLU()
        self.maxpool1 = MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        
        self.conv2 = Conv2d(in_channels=16, out_channels=32, kernel_size=(5, 5))
        self.relu2 = ReLU()
        self.maxpool2 = MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        
    def cnn_block(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.maxpool1(x)

        x = self.conv2(x)
        x = self.relu2(x)
        x = self.maxpool2(x)
        
        x, _  = torch.max(x, 1)
        x, _  = torch.max(x, 1)
        
        return x

    def forward(self, image):
        return self.cnn_block(image)

In [171]:
model = CNet()

In [172]:
for params in model.parameters():
    params.requires_grad = True
        
if torch.cuda.is_available():
    model.cuda()

## 3. Training

In [173]:
model.train()
print('Training')
train_running_loss = 0.0
train_running_correct = 0
counter = 0

dataloader_iterators = [ iter(train_loader) for train_loader in class_train_loaders]

Training


In [191]:
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
criterion = torch.nn.CrossEntropyLoss()

for i in range(1000):
    print(i)
    
    data_dict = {}

    for cl_idx, dataloader_iterator in enumerate(dataloader_iterators):

        try:
            data_dict[cl_idx] = next(dataloader_iterator)
        except StopIteration:
            # dataloader_iterator = iter(class_train_loaders[cl_idx])
            # data_dict[cl_idx] = next(dataloader_iterator)
            pass

    pred_dict = {}

    for cl_idx in list(data_dict.keys()):
        data = data_dict[cl_idx]

        pred_dict[cl_idx] = model(data[0].to(device))

    preds = torch.stack(list(pred_dict.values()))
    vectors = torch.nn.functional.normalize(preds, dim=-1)

    vector_a = vectors.permute(1, 0, 2)
    vector_b = vectors.permute(1, 2, 0)

    outputs = torch.matmul(vector_a, vector_b)

    labels = torch.eye(len(data_dict)).repeat(16, 1, 1).to(device)

    scores1 = torch.nn.Softmax(-1)(outputs)
    scores2 = torch.nn.Softmax(-2)(outputs)

    loss1 = torch.mean(torch.negative(torch.sum(torch.log(scores1) * labels, -1)))
    loss2 = torch.mean(torch.negative(torch.sum(torch.log(scores2) * labels, -2)))
    loss = loss1 + loss2

    train_running_loss += loss.item()
    # Calculate the accuracy.
    _, preds = torch.max(outputs.data, 1)
    # train_running_correct += (preds == labels).sum().item()
    # Backpropagation
    optimizer.zero_grad()
    loss.backward()
    # Update the weights.
    optimizer.step()
    

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
27

In [192]:
torch.argmax(outputs, dim=1)

tensor([[0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0]], device='cuda:0')

## 4. Evaluation