-
Notifications
You must be signed in to change notification settings - Fork 4
/
news_loader.py
33 lines (26 loc) · 1 KB
/
news_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import pandas as pd
from torch.utils.data import Dataset
# Abbreviations whose trailing period must not be treated as a sentence
# boundary, mapped to the temporary placeholders used to hide them while
# splitting. Placeholders contain no "." (so they cannot trigger a split) and
# none is a substring of another, so restoring them in order is unambiguous.
_DEFAULT_ABBREVIATIONS = {
    "J.": "$JTrumph$",
    "Dr.": "$dr$",
    "Mr.": "$mr$",
    "Mrs.": "$mrs$",
    "U.S.": "$us$",
}


def read_articles(article, *, abbreviations=None):
    """Split an article into sentences, keeping known abbreviations intact.

    A sentence boundary is a ". " sequence (the period is kept, the space is
    dropped) or a newline already present in the text. Before splitting, each
    abbreviation is swapped for a placeholder so its period is not mistaken
    for a boundary; the placeholders are swapped back in every sentence.

    Args:
        article: Article text to split.
        abbreviations: Optional mapping of abbreviation -> placeholder that
            replaces the default set. Placeholders must not contain "." or
            newlines.

    Returns:
        List of sentence strings in original order. Empty strings are kept
        (e.g. a trailing ". " or a blank line yields ``""``).

    Note:
        Any literal occurrence of a placeholder in *article* (e.g. "$dr$")
        is also converted to its abbreviation on output.
    """
    if abbreviations is None:
        abbreviations = _DEFAULT_ABBREVIATIONS

    # Hide abbreviation periods so the ". " split below ignores them.
    protected = article
    for abbreviation, placeholder in abbreviations.items():
        protected = protected.replace(abbreviation, placeholder)

    def restore(sentence):
        # Undo the protection step within a single sentence.
        for abbreviation, placeholder in abbreviations.items():
            sentence = sentence.replace(placeholder, abbreviation)
        return sentence

    sentences = protected.replace(". ", ".\n").split("\n")
    return [restore(sentence) for sentence in sentences]
class NewsDataset(Dataset):
    """Map-style dataset of news articles loaded from a CSV file.

    Item ``idx`` is the article text in column ``text_column`` of the
    ``idx``-th row (by position), split into sentences by ``read_articles``.
    """

    def __init__(self, dataset_path, text_column="content"):
        """Load the CSV at *dataset_path* into ``self.df``.

        Args:
            dataset_path: Anything ``pandas.read_csv`` accepts (path, URL or
                file-like object).
            text_column: Name of the column holding the article text.
        """
        self.df = pd.read_csv(dataset_path)
        self.text_column = text_column

    def __len__(self):
        """Return the number of rows (articles) in the CSV."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return the list of sentences of the article at position *idx*."""
        # .iloc is positional, matching the 0..len-1 index contract implied by
        # __len__; plain Series[...] is label-based and only coincides with
        # positions while the frame keeps its default RangeIndex.
        return read_articles(self.df[self.text_column].iloc[idx])