# Importation

In [2]:
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import os
import sys
import random 
import pickle

import models

from tqdm.notebook import tqdm
from multiprocessing import Pool

from torch_geometric.nn import summary

# Paramètres

In [3]:
# setup parameters

SEED = 1234
DATA_DIR = 'data'
DATASET_PATH = 'java-small-preprocessed-code2vec/java-small'
DATASET_NAME = 'java-small'
EMBEDDING_DIM = 128
DROPOUT = 0.25
BATCH_SIZE = 128
CHUNKS = 10
MAX_LENGTH = 200
LOG_EVERY = 1000 #print log of results after every LOG_EVERY batches
N_EPOCHS = 20
LOG_DIR = 'logs'
SAVE_DIR = 'checkpoints'
LOG_PATH = os.path.join(LOG_DIR, f'{DATASET_NAME}-log.txt')
MODEL_SAVE_PATH = os.path.join(SAVE_DIR, f'{DATASET_NAME}-model.pt')
LOAD = False #set true if you want to load model from MODEL_SAVE_PATH

device = torch.device('cuda')

## Log func

In [4]:
def logfunc(log):
    with open(LOG_PATH, 'a+') as f:
        f.write(log+'\n')
    print(log)

## Dir init

In [5]:
if not os.path.isdir(f'{SAVE_DIR}'):
    os.makedirs(f'{SAVE_DIR}')

if not os.path.isdir(f'{LOG_DIR}'):
    os.makedirs(f'{LOG_DIR}')

""" if os.path.exists(LOG_PATH):
    os.remove(LOG_PATH) """

' if os.path.exists(LOG_PATH):\n    os.remove(LOG_PATH) '

# Seed fixing

In [6]:
random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
# torch.backends.cudnn.deterministic = True

# Chargement des données

## Dict des word (variables), path, target

In [7]:
with open(f'{DATA_DIR}/{DATASET_PATH}/{DATASET_NAME}.dict.c2v', 'rb') as file:
    word2count = pickle.load(file)
    path2count = pickle.load(file)
    target2count = pickle.load(file)
    n_training_examples = pickle.load(file)

# create vocabularies, initialized with unk and pad tokens

word2idx = {'<unk>': 0, '<pad>': 1}
path2idx = {'<unk>': 0, '<pad>': 1 }
target2idx = {'<unk>': 0, '<pad>': 1}

for w in word2count.keys():
    word2idx[w] = len(word2idx)

for p in path2count.keys():
    path2idx[p] = len(path2idx)

for t in target2count.keys():
    target2idx[t] = len(target2idx)

idx2word = {v: k for k, v in word2idx.items()}
idx2path = {v: k for k, v in path2idx.items()}
idx2target = {v: k for k, v in target2idx.items()}

In [8]:
del pickle

In [9]:
len(idx2target), len(idx2word), len(idx2path)

(199749, 507272, 807139)

## File Reading

In [10]:
def load_data(file_path):
    with open(file_path, 'r') as f:
        return [
            (line.split(' ')[0], [t.split(',') for t in line.split(' ')[1:] if t.strip()])
            for line in f if len(line.split(' ')) - 1 <= MAX_LENGTH
        ]

In [None]:
def load_data(file_path):
    data = []
    
    with open(file_path, 'r') as f:
        for line in tqdm(f):
            parts = line.strip().split(' ')
            if len(parts) - 1 > MAX_LENGTH:
                continue
            
            name = target2idx.get(parts[0], target2idx['<unk>'])
            
            path_contexts = [tuple(t.split(',')) for t in parts[1:] if t.strip()]
            left, path, right = zip(*path_contexts) if path_contexts else ([], [], [])
            
            left_tensor = torch.tensor([word2idx.get(l, word2idx['<unk>']) for l in left], dtype=torch.long)
            path_tensor = torch.tensor([path2idx.get(p, path2idx['<unk>']) for p in path], dtype=torch.long)
            right_tensor = torch.tensor([word2idx.get(r, word2idx['<unk>']) for r in right], dtype=torch.long)

            data.append((torch.tensor(name, dtype=torch.long), left_tensor, path_tensor, right_tensor))
    
    return data

In [12]:
data_test = load_data(f'{DATA_DIR}/{DATASET_PATH}/{DATASET_NAME}.test.c2v')

In [13]:
data_val = load_data(f'{DATA_DIR}/{DATASET_PATH}/{DATASET_NAME}.val.c2v')

In [14]:
data_train = load_data(f'{DATA_DIR}/{DATASET_PATH}/{DATASET_NAME}.train.c2v')

In [15]:
len(data_test), len(data_val), len(data_train)

(56165, 23505, 665115)

In [16]:
n_training_examples

665115

## Data Loader

In [17]:
def collate_fn(samples):
    name_idx = torch.stack([e[0] for e in samples])
    
    max_length = max(len(e[1]) for e in samples)
    
    def pad_tensor(tensor_list, pad_value):
        return torch.stack([torch.cat([t, torch.full((max_length - len(t),), pad_value)]) for t in tensor_list])

    left_tensor = pad_tensor([e[1] for e in samples], word2idx['<pad>'])
    path_tensor = pad_tensor([e[2] for e in samples], path2idx['<pad>'])
    right_tensor = pad_tensor([e[3] for e in samples], word2idx['<pad>'])

    return name_idx, left_tensor, path_tensor, right_tensor


In [18]:
train_loader = DataLoader(data_train, batch_size=BATCH_SIZE, collate_fn=collate_fn,
                          pin_memory=True, shuffle=True, num_workers=0, prefetch_factor=None)
test_loader = DataLoader(data_test, batch_size=BATCH_SIZE, collate_fn=collate_fn, 
                         pin_memory=True, shuffle=False, num_workers=0, prefetch_factor=None)
eval_loader = DataLoader(data_val, batch_size=BATCH_SIZE, collate_fn=collate_fn, 
                         pin_memory=True, shuffle=False, num_workers=0, prefetch_factor=None)

In [19]:
len(train_loader), len(test_loader), len(eval_loader)

(5197, 439, 184)

In [20]:
c = [0 for i in range(4)]
for ts in tqdm(train_loader):
    for j, t in enumerate(ts):
        c[j] += t.eq(0).sum().item()
print(c)

  0%|          | 0/5197 [00:00<?, ?it/s]

[0, 0, 625202, 0]


In [21]:
del c, ts

In [22]:
m, Ma = sys.maxsize, 0
for v in path2count.values():
    m, Ma = min(m,v), max(Ma,v)
print((m, Ma))
del m, Ma

(5, 852233)


In [23]:
train_loader.desc = "train"
test_loader.desc = "test"
eval_loader.desc = "eval"

In [24]:
%whos


Variable              Type              Data/Info
-------------------------------------------------
BATCH_SIZE            int               128
CHUNKS                int               10
DATASET_NAME          str               java-small
DATASET_PATH          str               java-small-preprocessed-code2vec/java-small
DATA_DIR              str               data
DROPOUT               float             0.25
DataLoader            type              <class 'torch.utils.data.dataloader.DataLoader'>
EMBEDDING_DIM         int               128
F                     module            <module 'torch.nn.functio<...>orch\\nn\\functional.py'>
LOAD                  bool              False
LOG_DIR               str               logs
LOG_EVERY             int               1000
LOG_PATH              str               logs\java-small-log.txt
MAX_LENGTH            int               200
MODEL_SAVE_PATH       str               checkpoints\java-small-model.pt
N_EPOCHS              int               20


# Instanciation

In [25]:
model = models.Code2Vec(
    nodes_dim=      len(word2idx),      # nb de "var"
    paths_dim=      len(path2idx),      # nb de path
    embedding_dim=  EMBEDDING_DIM,      # à découpé
    output_dim=     len(target2idx),    # nb de classe
    dropout=        DROPOUT).to(device)

if LOAD:
    logfunc(f'Loading model from {MODEL_SAVE_PATH}')
    model.load_state_dict(torch.load(MODEL_SAVE_PATH))

optimizer = optim.Adam(model.parameters(), lr=0.001)

criterion = nn.CrossEntropyLoss().to(device)

## Overview

In [26]:
logfunc(f"Model structure: {model}\n\n")

Model structure: Code2Vec(
  (node_embedding): Embedding(507272, 128)
  (path_embedding): Embedding(807139, 128)
  (out): Linear(in_features=128, out_features=199749, bias=False)
  (do): Dropout(p=0.25, inplace=False)
)




In [27]:
for i in train_loader:
    a=i
    break
logfunc(summary(model, *[b.to(device) for b in a][1:]))
logfunc("\n")

+-----------------------------+------------------------------------+-----------------+-------------+
| Layer                       | Input Shape                        | Output Shape    | #Param      |
|-----------------------------+------------------------------------+-----------------+-------------|
| Code2Vec                    | [128, 200], [128, 200], [128, 200] | [128, 199749]   | 193,861,760 |
| ├─(node_embedding)Embedding | [128, 200]                         | [128, 200, 128] | 64,930,816  |
| ├─(path_embedding)Embedding | [128, 200]                         | [128, 200, 128] | 103,313,792 |
| ├─(out)Linear               | [128, 128]                         | [128, 199749]   | 25,567,872  |
| ├─(do)Dropout               | [128, 200, 384]                    | [128, 200, 384] | --          |
+-----------------------------+------------------------------------+-----------------+-------------+




## Profiling

In [40]:
from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], 
             record_shapes=True, profile_memory=True) as prof:
    
    for i, (tensor_n, tensor_l, tensor_p, tensor_r) in enumerate(tqdm(eval_loader, desc="eval for {eval_loader.desc} batch", position=1)) :
        if i >= 5:
            break
        # Move tensors to GPU
        tensor_n = tensor_n.to(device, non_blocking=True)
        tensor_l = tensor_l.to(device, non_blocking=True)
        tensor_p = tensor_p.to(device, non_blocking=True)
        tensor_r = tensor_r.to(device, non_blocking=True)
        torch.cuda.synchronize()
        
        optimizer.zero_grad()

        fx = model(tensor_l, tensor_p, tensor_r)
        loss = criterion(fx, tensor_n)

        loss.backward()
        optimizer.step()
        prof.step()

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=-1))
prof.export_chrome_trace("trace.json")

eval for {eval_loader.desc} batch:   0%|          | 0/184 [00:00<?, ?it/s]

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                               Optimizer.step#Adam.step         1.26%       3.985ms         2.95%       9.337ms       1.867ms     386.000us         0.01%        2.286s     457.260ms           0 b         -20 b           0 b      -3.61 G

In [None]:
prof.

In [28]:
[i.shape for i in a]

[torch.Size([128]),
 torch.Size([128, 200]),
 torch.Size([128, 200]),
 torch.Size([128, 200])]