In [1]:
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

import datasets

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Open the arXiv summarization train split as a stream; records are pulled lazily
# and the next cell materialises the first 3500 of them.
# NOTE(review): `datasets` warns (output below) that trust_remote_code=True will be
# required in a future release; opting in executes code from the dataset repo, so
# review that code before enabling it.
dataset = datasets.load_dataset(path="ccdv/arxiv-summarization", split="train", streaming=True)

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


In [3]:
# Materialise the first 3500 streamed records as a list (later cells iterate over it).
raw_dataset = [record for record in dataset.take(3500)]

# Batching layout, sketched with small numbers:
#   BATCH SIZE: 4 (papers)
#   CHUNK SIZE: 5 (each paper broken into 5 chunks of n tokens each)
# In the cells below the real values are batch_size=8 (In [21]) and 10 pieces per
# row (`segments`, In [4]). Each "paper" row is actually one chunk_size-character
# slice of an article (In [18]), rows are drawn in shuffled order, and each "chunk"
# column is one of the `segments` pieces fed in successive forward passes (In [34]).
#
#        forward pass 1 | FP 2    | FP 3    | FP 4    | FP 5    |
#
# paper 1:      chunk 1 | chunk 2 | chunk 3 | chunk 4 | chunk 5 |
# paper 2:      chunk 1 | chunk 2 | chunk 3 | chunk 4 | chunk 5 |
# paper 3:      chunk 1 | chunk 2 | chunk 3 | chunk 4 | chunk 5 |
# paper 4:      chunk 1 | chunk 2 | chunk 3 | chunk 4 | chunk 5 |
#
#
#
#        forward pass 6 | FP 7    | FP 8    | FP 9    | FP 10   |
#
# paper 5:      chunk 1 | chunk 2 | chunk 3 | chunk 4 | chunk 5 |
# paper 6:      chunk 1 | chunk 2 | chunk 3 | chunk 4 | chunk 5 |
# paper 7:      chunk 1 | chunk 2 | chunk 3 | chunk 4 | chunk 5 |
# paper 8:      chunk 1 | chunk 2 | chunk 3 | chunk 4 | chunk 5 |


In [4]:
# Seed torch's global RNG before the DataLoader shuffle and the model initialisation
# further down, so Restart & Run All reproduces the same run.
SEED = 0
torch.manual_seed(SEED)

# Sequence layout: each training row is `chunk_size` bytes, split by the training
# loop into `segments` pieces of about `segment_length` bytes each.
segments = 10
segment_length = 512
segment_lenght = segment_length  # misspelled alias kept so existing references still work

chunk_size = segments * segment_length

In [5]:
# Bytes per training row: segments * segment_lenght from the config cell.
chunk_size

5120

In [6]:
# Keep only the article body text from each streamed record.
raw_articles = [record["article"] for record in raw_dataset]

In [7]:
# Keep articles long enough to yield at least one full chunk of `chunk_size`
# characters. `>=` (not `>`) keeps articles that fill exactly one chunk, and using
# `chunk_size` keeps this filter in step with the config cell.
raw_articles = [article for article in raw_articles if len(article) >= chunk_size]

In [8]:
# Report how many articles survived the length filter.
print("number of articles", len(raw_articles), sep=" ")

number of articles 3401


In [9]:
def decode_text(tokens):
    """Rebuild a string from a sequence of integer character codes.

    Every element is passed through ``chr`` and the characters are concatenated,
    so a 1-D integer tensor row (as used in the segment preview cell) decodes back
    to the text it was built from.
    """
    return "".join(map(chr, tokens))

In [10]:
# decode_text(np.frombuffer(raw_articles[0].encode("utf-8"), dtype=np.uint8)[:512])  # preview the first 512 bytes of an article

In [11]:
# Union the per-article character sets directly instead of first joining every
# article into one huge string just to take its set.
unique_chars = set().union(*raw_articles)
print("character set length", len(unique_chars))
print("character set", ''.join(sorted(unique_chars)))

# Articles are later encoded one byte per character and used as indices into an
# nn.Embedding with 128 entries, so every character must be 7-bit ASCII.
assert all(ord(ch) < 128 for ch in unique_chars), "non-ASCII characters would overflow the 128-entry vocabulary"

character set length 70
character set 
 !"#$%&'()*+,-./0123456789:;<=>?@[\]^_`abcdefghijklmnopqrstuvwxyz{|}~


In [12]:
# np.fromstring's binary mode is deprecated in NumPy (see the warning echoed below
# this cell). Encode each article explicitly and read the bytes instead. UTF-8 is
# assumed to match fromstring's implicit encoding of str; for the ASCII-only corpus
# reported in In [11], every character is exactly one byte either way.
# np.frombuffer returns a read-only view of the bytes object, so .copy() yields
# ordinary writable uint8 arrays.
converted = [np.frombuffer(doc.encode("utf-8"), dtype=np.uint8).copy() for doc in raw_articles]

  converted=[np.fromstring(doc, dtype=np.uint8) for doc in raw_articles]


In [13]:
def clip_article(doc, chunk_size):
    """Trim ``doc`` so its length is a whole multiple of ``chunk_size``.

    Trailing elements that would not fill a complete chunk are dropped. A ``doc``
    whose length is already a multiple of ``chunk_size`` is returned whole; the
    previous ``doc[:-remainder]`` form returned an empty slice in that case,
    because ``doc[:-0]`` is ``doc[:0]``.

    Args:
        doc: any sliceable sequence with ``len`` (a 1-D numpy array in this notebook).
        chunk_size: positive chunk length.

    Returns:
        ``doc[:k * chunk_size]`` for the largest ``k`` that fits.
    """
    usable = len(doc) - len(doc) % chunk_size
    return doc[:usable]

In [14]:
# Drop each article's incomplete trailing chunk, using the configured chunk_size
# rather than repeating the literal.
clipped = [clip_article(doc, chunk_size) for doc in converted]

In [15]:
# Number of full chunks in the third clipped article (should be a whole number).
clipped[2].shape[0] / chunk_size

5.0

In [16]:
# Distinct 1-D shapes among the clipped articles.
shapes = {article.shape for article in clipped}

In [17]:
# Distinct article lengths after clipping; each should be a whole multiple of chunk_size.
shapes

{(5120,),
 (10240,),
 (15360,),
 (20480,),
 (25600,),
 (30720,),
 (35840,),
 (40960,),
 (46080,),
 (51200,),
 (56320,),
 (61440,),
 (66560,),
 (71680,),
 (76800,),
 (81920,),
 (87040,),
 (92160,),
 (97280,),
 (102400,),
 (107520,),
 (112640,),
 (117760,),
 (122880,),
 (128000,),
 (133120,),
 (138240,),
 (143360,),
 (148480,),
 (153600,),
 (158720,),
 (163840,),
 (184320,),
 (189440,),
 (194560,),
 (204800,),
 (209920,),
 (220160,),
 (225280,),
 (230400,),
 (245760,)}

In [18]:
# Split every clipped article into rows of chunk_size bytes: (n_chunks_i, chunk_size).
chunked = [article.reshape((-1, chunk_size)) for article in clipped]

In [19]:
# Stack the rows of all articles into one int64 tensor of shape
# (total_chunks, chunk_size); long dtype so the values can index nn.Embedding.
processed_data = torch.as_tensor(np.concatenate(chunked), dtype=torch.long)

In [20]:
# (number of chunks across all articles, chunk_size)
processed_data.shape

torch.Size([20853, 5120])

In [21]:
# Shuffled batches of 8 rows drawn from processed_data. `loader` is a single
# iterator shared by later cells: every next(loader) consumes one batch, and it
# raises StopIteration once all rows have been served (about 2607 batches for the
# 20853 rows shown above).
loader = iter(DataLoader(dataset=processed_data, shuffle=True, batch_size=8))

In [22]:
# Pull one batch for inspection. This consumes it from the shared `loader`
# iterator, so the training loop below starts from the following batch.
example=next(loader)

In [23]:
# (batch_size, chunk_size)
example.shape

torch.Size([8, 5120])

In [24]:
# Next-character prediction: targets are the inputs shifted left by one position.
seq = example[:, :-1]
labels = example[:, 1:]

In [25]:
# First 15 input character codes of the first row (decode_text turns them back into text).
seq[0][:15]

tensor([110,  32, 102, 111, 108, 108, 111, 119, 115,  32, 115, 105, 109, 105,
        108])

In [26]:
# One shorter than chunk_size because of the one-position shift between inputs and targets.
labels.shape

torch.Size([8, 5119])

In [27]:
# Preview the first row of the batch, split into the same `segments` pieces the
# training loop uses. Only inputs are decoded (targets are the same text shifted
# by one), so the labels are not chunked here.
for seq_segment in seq.chunk(segments, dim=-1):
    print(decode_text(seq_segment[0]), "\n**********\n")

n follows similarly to ( [ floctypeq ] ) . 
 + the statement follows similarly when @xmath182 holds by @xmath231 . 
 ( derivation by @xmath238 ) we have then that @xmath240 and thus @xmath311 . 
 [ lem : termintromonotone ] if @xmath207 and @xmath74 then @xmath328 . 
 let @xmath207 and @xmath74 . from the definition we have that @xmath182 . 
 from lemma [ lem : typeintromonotone ] we have that @xmath293 . by induction on the derivation of @xmath182 .    1 . 
 let @xmath182 by @xmath218 . 
 by induction on t 
**********

he derivation of @xmath207 
 let @xmath329 . 
 then @xmath330 and @xmath328 . 
 2 .   let @xmath205{\!:\!}a , k\notin{\textup{dom(}p\textup{)}}$ ] and @xmath331 . if @xmath150 then @xmath332 , since the derivation of @xmath333 is strictly smaller than the derivation @xmath207 , by ih , @xmath328 . otherwise , @xmath334 and @xmath335 { \!:\!}a$ ] . 
 but @xmath336 and by ih @xmath337 . by the definition @xmath328 . 
 + the statement follows similarly when @xmath182 is de

In [28]:
# Per-character MLP: every position is embedded and mapped to next-character logits
# independently (no layer mixes information across positions), so it acts like a
# learned bigram model over a 128-entry vocabulary.
model = nn.Sequential(
    nn.Embedding(num_embeddings=128, embedding_dim=16),  # (vocab_size, embedding_dim)
    nn.Linear(in_features=16, out_features=150),
    nn.ReLU(),
    nn.Linear(in_features=150, out_features=150),
    nn.ReLU(),
    nn.Linear(in_features=150, out_features=128),  # hidden -> vocab_size logits
)

In [29]:
# Next-character classification loss, averaged over every position in the batch.
loss_fn = nn.CrossEntropyLoss(reduction="mean")

In [30]:
# Plain SGD over all model weights, learning rate 0.05.
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.05)

In [31]:
# Put the model in training mode; the call returns the module, so the cell
# displays the architecture.
model.train(mode=True)

Sequential(
  (0): Embedding(128, 16)
  (1): Linear(in_features=16, out_features=150, bias=True)
  (2): ReLU()
  (3): Linear(in_features=150, out_features=150, bias=True)
  (4): ReLU()
  (5): Linear(in_features=150, out_features=128, bias=True)
)

In [32]:
# Derive the segment count from the config cell instead of hard-coding it a second
# time, so the training loop's split of each row cannot drift from
# chunk_size / segment_lenght if those are changed.
segments = chunk_size // segment_lenght

In [34]:
for i in range(200):
    # One shuffled batch of rows, shape (batch_size, chunk_size) = (8, 5120);
    # targets are the inputs shifted by one character.
    data = next(loader)
    seq = data[:, :-1]
    labels = data[:, 1:]
    train_loss = 0.

    # Walk each row in `segments` pieces (about 512 characters each), taking one
    # SGD step per piece.
    for seq_segment, labels_segment in zip(
        seq.chunk(segments, dim=-1), labels.chunk(segments, dim=-1)
    ):
        optimizer.zero_grad()
        # (B, T, vocab) -> (B, vocab, T): CrossEntropyLoss expects classes on dim 1.
        y_pred = model(seq_segment).transpose(1, 2)
        loss = loss_fn(y_pred, labels_segment)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # Every 5th batch, report the mean loss over that batch's segments.
    if not i % 5:
        print(train_loss / segments)

        

4.55968279838562
3.3730733394622803
3.046580100059509
2.939957523345947
2.9663415431976317
2.83984375
2.8321441888809202
2.7357513666152955
2.6825407981872558
2.609228801727295
2.621779370307922
2.7563969612121584
2.635390543937683
2.5794907808303833
2.5554212808609007
2.5501531600952148
2.5392120122909545
2.481417107582092
2.474931073188782
2.4647560596466063
2.477619981765747
2.4739154100418093
2.5278233528137206
2.4088573932647703
2.473329019546509
2.450107789039612
2.5675188064575196
2.6088327407836913
2.454421043395996
2.4861522197723387
2.4466771125793456
2.455324983596802
2.491788864135742
2.481079912185669
2.4183900594711303
2.4627229452133177
2.602856183052063
2.53012101650238
2.4814670085906982
2.4994054794311524
