Skip to content

Commit

Permalink
formatting with black and prettier
Browse files Browse the repository at this point in the history
  • Loading branch information
ahmedbesbes committed Apr 28, 2021
1 parent ba37aa4 commit 5931976
Show file tree
Hide file tree
Showing 8 changed files with 486 additions and 432 deletions.
117 changes: 61 additions & 56 deletions clr_parameters_finder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
'''
"""
This script finds the optimal parameters for learning rate scheduling:
- min_lr
Expand All @@ -20,7 +20,7 @@
reference: https://towardsdatascience.com/adaptive-and-cyclical-learning-rates-using-pytorch-2bf904d18dee
'''
"""

import math
import os
Expand Down Expand Up @@ -51,17 +51,21 @@ def run(args):

batch_size = args.batch_size

training_params = {"batch_size": batch_size,
"shuffle": True,
"num_workers": args.workers}
training_params = {
"batch_size": batch_size,
"shuffle": True,
"num_workers": args.workers,
}

texts, labels, number_of_classes, sample_weights = load_data(args)
train_texts, _, train_labels, _, _, _ = train_test_split(texts,
labels,
sample_weights,
test_size=args.validation_split,
random_state=42,
stratify=labels)
train_texts, _, train_labels, _, _, _ = train_test_split(
texts,
labels,
sample_weights,
test_size=args.validation_split,
random_state=42,
stratify=labels,
)

training_set = MyDataset(train_texts, train_labels, args)
training_generator = DataLoader(training_set, **training_params)
Expand All @@ -74,31 +78,31 @@ def run(args):

criterion = nn.CrossEntropyLoss()

if args.optimizer == 'sgd':
optimizer = torch.optim.SGD(
model.parameters(), lr=args.start_lr, momentum=0.9
)
elif args.optimizer == 'adam':
optimizer = torch.optim.Adam(
model.parameters(), lr=args.start_lr
)
if args.optimizer == "sgd":
optimizer = torch.optim.SGD(model.parameters(), lr=args.start_lr, momentum=0.9)
elif args.optimizer == "adam":
optimizer = torch.optim.Adam(model.parameters(), lr=args.start_lr)

start_lr = args.start_lr
end_lr = args.end_lr
lr_find_epochs = args.epochs
smoothing = args.smoothing

def lr_lambda(x): return math.exp(
x * math.log(end_lr / start_lr) / (lr_find_epochs * len(training_generator)))
def lr_lambda(x):
return math.exp(
x * math.log(end_lr / start_lr) / (lr_find_epochs * len(training_generator))
)

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

losses = []
learning_rates = []

for epoch in range(lr_find_epochs):
print(f'[epoch {epoch + 1} / {lr_find_epochs}]')
progress_bar = tqdm(enumerate(training_generator),
total=len(training_generator))
print(f"[epoch {epoch + 1} / {lr_find_epochs}]")
progress_bar = tqdm(
enumerate(training_generator), total=len(training_generator)
)
for iter, batch in progress_bar:
features, labels = batch
if torch.cuda.is_available():
Expand All @@ -124,41 +128,42 @@ def lr_lambda(x): return math.exp(
losses.append(loss)

plt.semilogx(learning_rates, losses)
plt.savefig('./plots/losses_vs_lr.png')
plt.savefig("./plots/losses_vs_lr.png")


if __name__ == "__main__":
parser = argparse.ArgumentParser(
'Character Based CNN for text classification')
parser.add_argument('--data_path', type=str,
default='./data/train.csv')
parser.add_argument('--validation_split', type=float, default=0.2)
parser.add_argument('--label_column', type=str, default='Sentiment')
parser.add_argument('--text_column', type=str, default='SentimentText')
parser.add_argument('--max_rows', type=int, default=None)
parser.add_argument('--chunksize', type=int, default=50000)
parser.add_argument('--encoding', type=str, default='utf-8')
parser.add_argument('--sep', type=str, default=',')
parser.add_argument('--steps', nargs='+', default=['lower'])
parser.add_argument('--group_labels', type=str,
default=None, choices=[None, 'binarize'])
parser.add_argument('--ratio', type=float, default=1)

parser.add_argument('--alphabet', type=str,
default='abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:\'"\\/|_@#$%^&*~`+-=<>()[]{}')
parser.add_argument('--number_of_characters', type=int, default=69)
parser.add_argument('--extra_characters', type=str, default='')
parser.add_argument('--max_length', type=int, default=150)
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--optimizer', type=str,
choices=['adam', 'sgd'], default='sgd')
parser.add_argument('--learning_rate', type=float, default=0.01)
parser.add_argument('--workers', type=int, default=1)

parser.add_argument('--start_lr', type=float, default=1e-5)
parser.add_argument('--end_lr', type=float, default=1e-2)
parser.add_argument('--smoothing', type=float, default=0.05)
parser.add_argument('--epochs', type=int, default=1)
parser = argparse.ArgumentParser("Character Based CNN for text classification")
parser.add_argument("--data_path", type=str, default="./data/train.csv")
parser.add_argument("--validation_split", type=float, default=0.2)
parser.add_argument("--label_column", type=str, default="Sentiment")
parser.add_argument("--text_column", type=str, default="SentimentText")
parser.add_argument("--max_rows", type=int, default=None)
parser.add_argument("--chunksize", type=int, default=50000)
parser.add_argument("--encoding", type=str, default="utf-8")
parser.add_argument("--sep", type=str, default=",")
parser.add_argument("--steps", nargs="+", default=["lower"])
parser.add_argument(
"--group_labels", type=str, default=None, choices=[None, "binarize"]
)
parser.add_argument("--ratio", type=float, default=1)

parser.add_argument(
"--alphabet",
type=str,
default="abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"\\/|_@#$%^&*~`+-=<>()[]{}",
)
parser.add_argument("--number_of_characters", type=int, default=69)
parser.add_argument("--extra_characters", type=str, default="")
parser.add_argument("--max_length", type=int, default=150)
parser.add_argument("--batch_size", type=int, default=128)
parser.add_argument("--optimizer", type=str, choices=["adam", "sgd"], default="sgd")
parser.add_argument("--learning_rate", type=float, default=0.01)
parser.add_argument("--workers", type=int, default=1)

parser.add_argument("--start_lr", type=float, default=1e-5)
parser.add_argument("--end_lr", type=float, default=1e-2)
parser.add_argument("--smoothing", type=float, default=0.05)
parser.add_argument("--epochs", type=int, default=1)

args = parser.parse_args()
run(args)
114 changes: 46 additions & 68 deletions config.json
Original file line number Diff line number Diff line change
@@ -1,71 +1,49 @@
{
"alphabet": {
"en": {
"lower": {
"alphabet": "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]{}",
"number_of_characters": 69
},
"both": {
"alphabet": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]{}",
"number_of_characters": 95
}
}
},
"alphabet": {
"en": {
"lower": {
"alphabet": "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]{}",
"number_of_characters": 69
},
"both": {
"alphabet": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]{}",
"number_of_characters": 95
}
}
},

"model_parameters": {
"small": {
"conv": [
[
256,
7,
3
],
[
256,
7,
3
],
[
256,
3,
-1
],
[
256,
3,
-1
],
[
256,
3,
-1
],
[
256,
3,
3
]
],
"fc": [
1024,
1024
]
}
},
"data": {
"text_column": "SentimentText",
"label_column": "Sentiment",
"max_length": 150,
"num_of_classes": 2,
"encoding": null,
"chunksize": 50000,
"max_rows": 100000,
"preprocessing_steps": ["lower", "remove_hashtags", "remove_urls", "remove_user_mentions"]
},
"training": {
"batch_size": 128,
"learning_rate": 0.01,
"epochs": 10,
"optimizer": "sgd"
"model_parameters": {
"small": {
"conv": [
[256, 7, 3],
[256, 7, 3],
[256, 3, -1],
[256, 3, -1],
[256, 3, -1],
[256, 3, 3]
],
"fc": [1024, 1024]
}
}
},
"data": {
"text_column": "SentimentText",
"label_column": "Sentiment",
"max_length": 150,
"num_of_classes": 2,
"encoding": null,
"chunksize": 50000,
"max_rows": 100000,
"preprocessing_steps": [
"lower",
"remove_hashtags",
"remove_urls",
"remove_user_mentions"
]
},
"training": {
"batch_size": 128,
"learning_rate": 0.01,
"epochs": 10,
"optimizer": "sgd"
}
}
38 changes: 21 additions & 17 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,19 @@

use_cuda = torch.cuda.is_available()


def predict(args):
model = CharacterLevelCNN(args, args.number_of_classes)
state = torch.load(args.model)
model.load_state_dict(state)
model.eval()

processed_input = utils.preprocess_input(args)
processed_input = torch.tensor(processed_input)
processed_input = processed_input.unsqueeze(0)
if use_cuda:
processed_input = processed_input.to('cuda')
model = model.to('cuda')
processed_input = processed_input.to("cuda")
model = model.to("cuda")
prediction = model(processed_input)
probabilities = F.softmax(prediction, dim=1)
probabilities = probabilities.detach().cpu().numpy()
Expand All @@ -26,22 +27,25 @@ def predict(args):

if __name__ == "__main__":
parser = argparse.ArgumentParser(
'Testing a pretrained Character Based CNN for text classification')
parser.add_argument('--model', type=str, help='path for pre-trained model')
parser.add_argument('--text', type=str,
default='I love pizza!', help='text string')
parser.add_argument('--steps', nargs="+", default=['lower'])
"Testing a pretrained Character Based CNN for text classification"
)
parser.add_argument("--model", type=str, help="path for pre-trained model")
parser.add_argument("--text", type=str, default="I love pizza!", help="text string")
parser.add_argument("--steps", nargs="+", default=["lower"])

# arguments needed for the prediction
parser.add_argument('--alphabet', type=str,
default="abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\\|_@#$%^&*~`+ =<>()[]{}")
parser.add_argument('--number_of_characters', type=int, default=69)
parser.add_argument('--extra_characters', type=str, default="éàèùâêîôûçëïü")
parser.add_argument('--max_length', type=int, default=300)
parser.add_argument('--number_of_classes', type=int, default=2)
parser.add_argument(
"--alphabet",
type=str,
default="abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\\|_@#$%^&*~`+ =<>()[]{}",
)
parser.add_argument("--number_of_characters", type=int, default=69)
parser.add_argument("--extra_characters", type=str, default="éàèùâêîôûçëïü")
parser.add_argument("--max_length", type=int, default=300)
parser.add_argument("--number_of_classes", type=int, default=2)

args = parser.parse_args()
prediction = predict(args)
print('input : {}'.format(args.text))
print('prediction : {}'.format(prediction))

print("input : {}".format(args.text))
print("prediction : {}".format(prediction))

0 comments on commit 5931976

Please sign in to comment.